diff --git a/.config/1espt/PipelineAutobaseliningConfig.yml b/.config/1espt/PipelineAutobaseliningConfig.yml index 79de7c7d63540..18315e55e854d 100644 --- a/.config/1espt/PipelineAutobaseliningConfig.yml +++ b/.config/1espt/PipelineAutobaseliningConfig.yml @@ -18,7 +18,7 @@ pipelines: credscan: lastModifiedDate: 2025-02-06 binskim: - lastModifiedDate: 2025-02-06 + lastModifiedDate: 2025-04-25 spotbugs: lastModifiedDate: 2025-02-06 usedNonDefaultBranch: true @@ -39,7 +39,7 @@ pipelines: credscan: lastModifiedDate: 2024-10-25 binskim: - lastModifiedDate: 2024-10-25 + lastModifiedDate: 2025-04-25 spotbugs: lastModifiedDate: 2024-10-25 1625: @@ -59,7 +59,7 @@ pipelines: credscan: lastModifiedDate: 2024-11-13 binskim: - lastModifiedDate: 2024-11-13 + lastModifiedDate: 2025-04-25 spotbugs: lastModifiedDate: 2024-11-13 1626: @@ -82,3 +82,97 @@ pipelines: lastModifiedDate: 2024-11-13 spotbugs: lastModifiedDate: 2024-11-13 + 995: + retail: + source: + credscan: + lastModifiedDate: 2025-02-12 + eslint: + lastModifiedDate: 2025-02-12 + psscriptanalyzer: + lastModifiedDate: 2025-02-12 + armory: + lastModifiedDate: 2025-02-12 + 1313: + retail: + source: + credscan: + lastModifiedDate: 2025-02-27 + eslint: + lastModifiedDate: 2025-02-27 + psscriptanalyzer: + lastModifiedDate: 2025-02-27 + armory: + lastModifiedDate: 2025-02-27 + binary: + credscan: + lastModifiedDate: 2025-02-27 + binskim: + lastModifiedDate: 2025-04-25 + spotbugs: + lastModifiedDate: 2025-02-27 + 1312: + retail: + source: + credscan: + lastModifiedDate: 2025-02-27 + eslint: + lastModifiedDate: 2025-02-27 + psscriptanalyzer: + lastModifiedDate: 2025-02-27 + armory: + lastModifiedDate: 2025-02-27 + 841: + retail: + source: + credscan: + lastModifiedDate: 2025-04-24 + eslint: + lastModifiedDate: 2025-04-24 + psscriptanalyzer: + lastModifiedDate: 2025-04-24 + armory: + lastModifiedDate: 2025-04-24 + binary: + credscan: + lastModifiedDate: 2025-04-25 + binskim: + lastModifiedDate: 2025-04-25 + spotbugs: + lastModifiedDate: 2025-04-25 + 1757: + retail: + source: + credscan: + lastModifiedDate: 2025-04-25 + eslint: + lastModifiedDate: 2025-04-25 + psscriptanalyzer: + lastModifiedDate: 2025-04-25 + armory: + lastModifiedDate: 2025-04-25 + binary: + credscan: + lastModifiedDate: 2025-04-25 + binskim: + lastModifiedDate: 2025-04-25 + spotbugs: + lastModifiedDate: 2025-04-25 + 1234: + retail: + source: + credscan: + lastModifiedDate: 2025-04-25 + eslint: + lastModifiedDate: 2025-04-25 + psscriptanalyzer: + lastModifiedDate: 2025-04-25 + armory: + lastModifiedDate: 2025-04-25 + binary: + credscan: + lastModifiedDate: 2025-04-25 + binskim: + lastModifiedDate: 2025-04-25 + spotbugs: + lastModifiedDate: 2025-04-25 diff --git a/.config/guardian/.gdnbaselines b/.config/guardian/.gdnbaselines index e976c78bf4a12..7246ad6ba36df 100644 --- a/.config/guardian/.gdnbaselines +++ b/.config/guardian/.gdnbaselines @@ -52,6 +52,363 @@ "createdDate": "2025-02-06 15:53:46Z", "expirationDate": "2025-07-26 16:26:55Z", "justification": "This error is baselined with an expiration date of 180 days from 2025-02-06 16:26:55Z" + }, + "4a0e83898f2607b442b095973eed78a44648024954e8ef2188e6ba3786771620": { + "signature": "4a0e83898f2607b442b095973eed78a44648024954e8ef2188e6ba3786771620", + "alternativeSignatures": [ + "68a14c0d46c7eb93178e545e06454f12e9d1c98e4e9ecbe009e11cd3eff4f682" + ], + "target": "file:///D:/a/_work/_temp/RelWithDebInfo/RelWithDebInfo/dxcompiler.dll", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-04-25 11:09:32Z", + "expirationDate": "2025-10-12 11:33:39Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 11:33:39Z" + }, + "b3a97d9c774e372e8204bdd6abe8e5bc5fdd46799adaeb0f666838042fdd8f99": { + "signature": "b3a97d9c774e372e8204bdd6abe8e5bc5fdd46799adaeb0f666838042fdd8f99", + "alternativeSignatures": [ + "79f7ba0d65b89c586e684239b8488e9d8cf58c53ce9fedb323dbee5096378e08" + ], + "target": "file:///D:/a/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnxruntime.dll", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-04-25 11:09:32Z", + "expirationDate": "2025-10-12 11:33:39Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 11:33:39Z" + }, + "9836b7915d2d6019bd245fd42935c164e0f9db8c867c8fe26cd6c2f576927aa2": { + "signature": "9836b7915d2d6019bd245fd42935c164e0f9db8c867c8fe26cd6c2f576927aa2", + "alternativeSignatures": [ + "1081cda6083ef05dfe233be730c16c300ef887ba2b0bd8908b2f255ba666f3be" + ], + "target": "file:///D:/a/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnxruntime_autoep_test.exe", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-04-25 11:09:32Z", + "expirationDate": "2025-10-12 11:33:39Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 11:33:39Z" + }, + "5d44b5ab5f1d37e0368a85b56bf7460c97e329d626954edd761ace4890e5b6d9": { + "signature": "5d44b5ab5f1d37e0368a85b56bf7460c97e329d626954edd761ace4890e5b6d9", + "alternativeSignatures": [ + "acaec6efd0587a90c9d9ae4dc9893cd8ae4cec325f0263eb710444ca16f76e86" + ], + "target": "file:///D:/a/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnxruntime_perf_test.exe", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-04-25 11:09:32Z", + "expirationDate": "2025-10-12 11:33:39Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 11:33:39Z" + }, + "55ad483be6385b220d416d13e2815669297a2e9abb23fb12c4c8057cba628808": { + "signature": "55ad483be6385b220d416d13e2815669297a2e9abb23fb12c4c8057cba628808", + "alternativeSignatures": [ + "64e9ad74cf4e0c10e1bc5c07140f0b8ed101b123a96565b112ee81ccaa3a942a" + ], + "target": "file:///D:/a/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnxruntime_test_all.exe", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-04-25 11:09:32Z", + "expirationDate": "2025-10-12 11:33:39Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 11:33:39Z" + }, + "20edb5d58ae9fe1784a826b3cd03ec897234691bcb117b076ee0b1da3c4c4468": { + "signature": "20edb5d58ae9fe1784a826b3cd03ec897234691bcb117b076ee0b1da3c4c4468", + "alternativeSignatures": [ + "dc64851ea95bd49082445a0f63dd2eb843410d661388ff5b64303f71aed44ad4" + ], + "target": "file:///D:/a/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnx_test_runner.exe", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-04-25 11:09:32Z", + "expirationDate": "2025-10-12 11:33:39Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 11:33:39Z" + }, + "6a04730bbf5429b2a82eec02edf9a10199d16a2956ef3bc4b11795318e44d7b3": { + "signature": "6a04730bbf5429b2a82eec02edf9a10199d16a2956ef3bc4b11795318e44d7b3", + "alternativeSignatures": [ + "fae795cbfb44b883a7ef789394c107474257e3256a60fc4f8142634875d2df47" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-arm64/native/onnxruntime_perf_test.exe", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "6b12997cff49c14a6472fa08cd6aae83767f6fea72aa2654caca0a7afbe1e8aa": { + "signature": "6b12997cff49c14a6472fa08cd6aae83767f6fea72aa2654caca0a7afbe1e8aa", + "alternativeSignatures": [ + "4349b8b6474089bf80e75d9ef3e41d904f91caac91fb5b64914ccbad1758e874" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-arm64/native/onnx_test_runner.exe", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "5e3b583bf904b2985056f9b98153ac5be048864ce31c78219e143727c834d988": { + "signature": "5e3b583bf904b2985056f9b98153ac5be048864ce31c78219e143727c834d988", + "alternativeSignatures": [ + "0edba8b6930210ec72d5d7ae32f20150656e16110734e5325665571f873b8787" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-arm64/native/QnnCpu.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "32c1e4332e93d446403d313dee13b68d84e9ae65b45dd624bb07bea27c211b03": { + "signature": "32c1e4332e93d446403d313dee13b68d84e9ae65b45dd624bb07bea27c211b03", + "alternativeSignatures": [ + "55642f7417832dd59f767472b0d50492e273411880395c8fb16ccff013022271" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-arm64/native/QnnGpu.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "79ad215ffb5f1a3c51f4061b877f5d1d65eb71a5c1acedd9fe0410b20f797df1": { + "signature": "79ad215ffb5f1a3c51f4061b877f5d1d65eb71a5c1acedd9fe0410b20f797df1", + "alternativeSignatures": [ + "570709e99467a902da4d446c6dbda5cebc52f84e6272834ec813204060d76188" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-arm64/native/QnnHtp.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "00e3517a665b3fc013745cb76d0aad5d6091828c0325d5a3230fdae837f8f971": { + "signature": "00e3517a665b3fc013745cb76d0aad5d6091828c0325d5a3230fdae837f8f971", + "alternativeSignatures": [ + "f40aab5b7b16c3fbd0562d4d395897370f5409befabb9cb10be9c4a5116001ed" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-arm64/native/QnnHtpPrepare.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "341396cfcbf2cd00c8b505c6fdd4fd56cd45c2466e548c6bfca2ded31ebbc371": { + "signature": "341396cfcbf2cd00c8b505c6fdd4fd56cd45c2466e548c6bfca2ded31ebbc371", + "alternativeSignatures": [ + "0515e21e0cc4f70161d08fe02bf41a1f7818db8ee22a152c550eb75019db2eee" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-arm64/native/QnnHtpV68Stub.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "2fba96257861f1ba2b702bc8c1e9242ef9c71e90529d6dd830e78ae0280b5a50": { + "signature": "2fba96257861f1ba2b702bc8c1e9242ef9c71e90529d6dd830e78ae0280b5a50", + "alternativeSignatures": [ + "e0f5edd83dd731f4774d381587a99f7aa7bbff367931474d2a03d2fd0c284888" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-arm64/native/QnnHtpV73Stub.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "f92f8f9eab05869e80b374408af4f30697bc49c6ff79511e6833cc5dcc0eaf9c": { + "signature": "f92f8f9eab05869e80b374408af4f30697bc49c6ff79511e6833cc5dcc0eaf9c", + "alternativeSignatures": [ + "48b6850d8ea65e1dc6028afba9530341e22be187eb497c791f06c592e06b070e" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-arm64/native/QnnSaver.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "688b86bbaa9d662204784f37c9c184e532a5005429eb654dd31221b27affa28a": { + "signature": "688b86bbaa9d662204784f37c9c184e532a5005429eb654dd31221b27affa28a", + "alternativeSignatures": [ + "42fc3288553ee6bed51b3c0d5e1010b0d7d2efbaf805bc467b99e9b472fb5382" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-arm64/native/QnnSystem.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "0c8d7037690ad318b1d094798d596dba3882a310dc68f97b3815a89d80c540a7": { + "signature": "0c8d7037690ad318b1d094798d596dba3882a310dc68f97b3815a89d80c540a7", + "alternativeSignatures": [ + "b450b5782f69aae532807936726d9c4af88806b0f49ae0ae33cd38a76bdc8a16" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-x64/native/onnxruntime_perf_test.exe", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "16744137fb39b77c89d7a4770f34e8ff4de6a0b94a080be2930daff16a4e4728": { + "signature": "16744137fb39b77c89d7a4770f34e8ff4de6a0b94a080be2930daff16a4e4728", + "alternativeSignatures": [ + "d6b6fc769d11c6c5998c3f7a4ad9aa0458f4ec7d83ed4e633f66a0f605820536" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-x64/native/onnx_test_runner.exe", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "17f65b2847b5e2d9db98da9877df029306d696896a7792cc307d3cd9db05d1b2": { + "signature": "17f65b2847b5e2d9db98da9877df029306d696896a7792cc307d3cd9db05d1b2", + "alternativeSignatures": [ + "bc5561d1f532f05ea310e9086f4755cbc3f30e8d7aa28ae62b127fd6d8332850" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-x64/native/QnnCpu.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "6eab05d067b751e37396a947650b38645d3b18c54065e9fba45c39c140e829b5": { + "signature": "6eab05d067b751e37396a947650b38645d3b18c54065e9fba45c39c140e829b5", + "alternativeSignatures": [ + "c8a0788fe4624f8c8c5e81f6adfb2ef29b4d801fe7d970fc3961aa6958f7f0f3" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-x64/native/QnnHtp.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "dee40d055d13313aec18920bc2f0a0a72bfc6deb836fd7a0a92245fc501bb73c": { + "signature": "dee40d055d13313aec18920bc2f0a0a72bfc6deb836fd7a0a92245fc501bb73c", + "alternativeSignatures": [ + "70da4f93be9acc7617d47270be9c790c970f0c12d6eaa642739d0d1b413eb7b5" + ], + "target": "Microsoft.ML.OnnxRuntime.QNN.1.23.0-dev-20250425-0443-1e118d6/runtimes/win-x64/native/QnnSaver.dll", + "memberOf": [ + "default" + ], + "tool": "codesign", + "ruleId": "CodeSign.MissingSigningCert", + "createdDate": "2025-04-25 17:06:52Z", + "expirationDate": "2025-10-12 17:09:03Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 17:09:03Z" + }, + "568199e534a7aaaa1f2f1395eba1edad215e4a4b1f71457c84868644ca2b5997": { + "signature": "568199e534a7aaaa1f2f1395eba1edad215e4a4b1f71457c84868644ca2b5997", + "alternativeSignatures": [], + "target": "ScanTelemetry_20250425220821141.json", + "line": 1, + "memberOf": [ + "default" + ], + "tool": "credscan", + "ruleId": "CSCAN-AZURE0130", + "createdDate": "2025-04-25 22:25:47Z", + "expirationDate": "2025-10-12 23:01:19Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 23:01:19Z" + }, + "46cd93fda3b6f7eec5d6c3f72169c3ad1092216fa6c9630177d7f2d6e940e555": { + "signature": "46cd93fda3b6f7eec5d6c3f72169c3ad1092216fa6c9630177d7f2d6e940e555", + "alternativeSignatures": [], + "target": "ScanTelemetry_20250425220811012.json", + "line": 1, + "memberOf": [ + "default" + ], + "tool": "credscan", + "ruleId": "CSCAN-AZURE0130", + "createdDate": "2025-04-25 22:25:52Z", + "expirationDate": "2025-10-12 23:01:19Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 23:01:19Z" + }, + "919e797b88639882a09d577e57832df0158ac312a221697faf9fc9a40021d8a0": { + "signature": "919e797b88639882a09d577e57832df0158ac312a221697faf9fc9a40021d8a0", + "alternativeSignatures": [], + "target": "ScanTelemetry_20250425220803349.json", + "line": 1, + "memberOf": [ + "default" + ], + "tool": "credscan", + "ruleId": "CSCAN-AZURE0130", + "createdDate": "2025-04-25 22:25:55Z", + "expirationDate": "2025-10-12 23:01:19Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 23:01:19Z" } } } \ No newline at end of file diff --git a/.github/actions/locate-vcvarsall-and-setup-env/action.yml b/.github/actions/locate-vcvarsall-and-setup-env/action.yml index 2fe3658b465c0..3066721e797ea 100644 --- a/.github/actions/locate-vcvarsall-and-setup-env/action.yml +++ b/.github/actions/locate-vcvarsall-and-setup-env/action.yml @@ -14,10 +14,10 @@ runs: steps: - name: Setup VCPKG - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.5 + uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 with: - vcpkg-version: '2025.03.19' - vcpkg-hash: '17e96169cd3f266c4716fcdc1bb728e6a64f103941ece463a2834d50694eba4fb48f30135503fd466402afa139abc847ef630733c442595d1c34979f261b0114' + vcpkg-version: '2025.04.09' + vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' cmake-version: '3.31.6' cmake-hash: '0f1584e8666cf4a65ec514bd02afe281caabf1d45d2c963f3151c41484f457386aa03273ab25776a670be02725354ce0b46f3a5121857416da37366342a833a0' add-cmake-to-path: 'true' diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index a456d8b036e7b..69ff9a1cec976 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -13,10 +13,111 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} + cancel-in-progress: true + +permissions: + contents: read + packages: write + attestations: write + id-token: write jobs: + AndroidBinarySizeCheckJob_MinimalBaseline: + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: Setup Android NDK + uses: ./.github/actions/setup-android-ndk + with: + ndk-version: 28.0.13004108 + + - name: Get Docker Image using Action + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + id: build_docker_image_step + with: + dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecpubuildcix64 + push: true + azure-container-registry-name: onnxruntimebuildcache + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Set variables from config file + id: set_vars + run: | + import json, os + + config_file_path = "tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config" + with open(config_file_path, mode="r") as config_file: + config = json.load(config_file) + + def set_var(name, value): + print(f"Setting variable: {name} = '{value}'") + # Use GITHUB_ENV for setting environment variables + with open(os.environ['GITHUB_ENV'], 'a') as f: + f.write(f"{name}={value}\n") + + set_var("BuildConfigType", config["type"]) + set_var("BuildConfigOs", config["os"]) + shell: python + working-directory: ${{ github.workspace }} + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: 1a. Build onnxruntime + run: | + set -e -x + BINARY_SIZE_THRESHOLD_ARGS="" + echo "Binary size threshold in bytes: 1306224" + BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1306224" + + # Ensure ANDROID_NDK_HOME is available and get its real path + if [ -z "$ANDROID_NDK_HOME" ]; then + echo "ANDROID_NDK_HOME is not set." + exit 1 + fi + NDK_HOME_REALPATH=$(realpath $ANDROID_NDK_HOME) + + # Ensure ANDROID_HOME is available + if [ -z "$ANDROID_HOME" ]; then + echo "ANDROID_HOME is not set. Using default /usr/local/lib/android/sdk" + export ANDROID_HOME=/usr/local/lib/android/sdk + fi + + docker run -e SYSTEM_COLLECTIONURI --rm \ + --volume ${{ github.workspace }}:/onnxruntime_src \ + --volume ${{ runner.temp }}:/build \ + --volume $ANDROID_HOME:/android_home \ + --volume $NDK_HOME_REALPATH:/ndk_home \ + -w /onnxruntime_src \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ + -e NIGHTLY_BUILD=1 \ + -e BUILD_BUILDNUMBER=${{ github.run_number }} \ + -e BUILD_SOURCEVERSION=${{ github.sha }} \ + -e BUILD_ID=${{ github.run_id }} \ + -e BUILD_REASON=${{ github.event_name }} \ + -e BUILD_BRANCH=${{ github.ref }} \ + -e ACTIONS_CACHE_URL \ + -e ACTIONS_RUNTIME_TOKEN \ + -e RUNNER_TEMP=/build \ + ${{ steps.build_docker_image_step.outputs.full-image-name }} \ + bash -c "python3 -m pip install -r /onnxruntime_src/tools/ci_build/requirements/pybind/requirements.txt && \ + python3 tools/ci_build/github/linux/ort_minimal/build_ort_and_check_binary_size.py \ + --build_dir /build/1a \ + ${BINARY_SIZE_THRESHOLD_ARGS} \ + tools/ci_build/github/linux/ort_minimal/build_check_binsize_config/android_minimal_baseline.config" + shell: bash + android_nnapi_ep: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: @@ -30,10 +131,10 @@ jobs: architecture: x64 - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.5 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 with: - vcpkg-version: '2025.03.19' - vcpkg-hash: '17e96169cd3f266c4716fcdc1bb728e6a64f103941ece463a2834d50694eba4fb48f30135503fd466402afa139abc847ef630733c442595d1c34979f261b0114' + vcpkg-version: '2025.04.09' + vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' cmake-version: '3.31.6' cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04' add-cmake-to-path: 'true' @@ -71,7 +172,9 @@ jobs: - name: Build Minimal ORT with NNAPI and run tests - run: tools/ci_build/github/linux/ort_minimal/nnapi_minimal_build_minimal_ort_and_run_tests.sh "$(pwd)" + run: + tools/ci_build/github/linux/ort_minimal/nnapi_minimal_build_minimal_ort_and_run_tests.sh + "$(pwd)" shell: bash - name: Install psutil for emulator shutdown by run_android_emulator.py diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index b0ba518242aa8..e53626d879dd1 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -42,7 +42,7 @@ jobs: with: python-version: "3.12" architecture: ${{ env.buildArch }} - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.5 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 with: vcpkg-version: '2025.03.19' vcpkg-hash: '17e96169cd3f266c4716fcdc1bb728e6a64f103941ece463a2834d50694eba4fb48f30135503fd466402afa139abc847ef630733c442595d1c34979f261b0114' diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index 96475c8313793..d3a54e1506e39 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -50,7 +50,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.5 + - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -93,7 +93,7 @@ jobs: # So build.py --build_dir build/Release inside the container correctly finds the artifacts. - name: Test ONNX Runtime id: test_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: Release diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml new file mode 100644 index 0000000000000..e68ef56cdb1ce --- /dev/null +++ b/.github/workflows/linux_minimal_build.yml @@ -0,0 +1,659 @@ +name: Linux CPU Minimal Build E2E + +on: + push: + branches: + - main + - rel-* + pull_request: + branches: + - main + - rel-* + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref + || github.sha }} + cancel-in-progress: true + +env: + BUILD_SOURCES_DIRECTORY: ${{ github.workspace }} + +jobs: + # Job 1: Build full onnxruntime and generate ORT format test files + build_full_ort: + name: 1. Build Full ORT and Generate ORT Files + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: + contents: read + packages: write + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - uses: actions/setup-node@v4 + with: + node-version: 20 + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + with: + vcpkg-version: '2025.04.09' + vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' + cmake-version: '3.31.6' + cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04' + add-cmake-to-path: 'true' + disable-terrapin: 'true' + + - name: Build Full ORT and Prepare Test Files + uses: microsoft/onnxruntime-github-actions/build-and-prep-ort-files@v0.0.6 + + - name: Upload Test Data Artifact + uses: actions/upload-artifact@v4 + with: + name: test_data + path: ${{ runner.temp }}/minimal_build_test_data/ + if-no-files-found: error # Fail if test data wasn't generated + + # Job 2: Build minimal onnxruntime [exceptions DISABLED, type reduction DISABLED, training ops ENABLED] + build_minimal_exceptions_disabled: + name: 2. Build Minimal (Exceptions Disabled) + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: # Permissions needed for build-docker-image + contents: read + packages: write + id-token: write # If using OIDC for ACR login + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + - uses: actions/setup-node@v4 + with: + node-version: 20 + + - name: Get Docker Image using Action + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + id: build_docker_image_step + with: + dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecpubuildcix64 + push: true + azure-container-registry-name: onnxruntimebuildcache + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Run Build 2 (Update) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: Debug # From original --config Debug + mode: 'update' # CMake configure step + extra_build_flags: >- + --cmake_generator Ninja + --use_binskim_compliant_compile_flags + --skip_tests + --minimal_build + --disable_exceptions + --enable_training_ops + + - name: Run Build 2 (Build) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: Debug # From original --config Debug + mode: 'build' # Actual build step + extra_build_flags: >- + --cmake_generator Ninja + --use_binskim_compliant_compile_flags + --skip_tests + --minimal_build + --disable_exceptions + --enable_training_ops + + # Job 3a: Build minimal onnxruntime [exceptions ENABLED, type reduction DISABLED, custom ops ENABLED] and run tests + build_minimal_custom_ops: + name: 3a. Build Minimal (Custom Ops) + needs: build_full_ort # Depends on Job 1 for test data + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: # Permissions needed for build-docker-image + contents: read + packages: write + id-token: write # If using OIDC for ACR login + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + - uses: actions/setup-node@v4 + with: + node-version: 20 + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + with: + vcpkg-version: '2025.04.09' + vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' + cmake-version: '3.31.6' + cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04' + add-cmake-to-path: 'true' + disable-terrapin: 'true' + + - name: Build Full ORT and Prepare Test Files + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.6 + with: + reduced-ops-config-file: required_ops.ort_models.config + enable-custom-ops: 'true' + binary-size-report-name-prefix: "3a" + + # Job 3b: Build minimal onnxruntime [exceptions ENABLED, type reduction ENABLED] and run tests + build_minimal_type_reduction: + name: 3b. Build Minimal (Type Reduction) + needs: build_full_ort # Depends on Job 1 for test data + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: # Permissions needed for build-docker-image + contents: read + packages: write + id-token: write # If using OIDC for ACR login + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + - uses: actions/setup-node@v4 + with: + node-version: 20 + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + with: + vcpkg-version: '2025.04.09' + vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' + cmake-version: '3.31.6' + cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04' + add-cmake-to-path: 'true' + disable-terrapin: 'true' + - name: Build Full ORT and Prepare Test Files + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.6 + with: + reduced-ops-config-file: required_ops_and_types.ort_models.config + enable-type-reduction: 'true' + binary-size-report-name-prefix: "3b" + + # Job 4: Build minimal onnxruntime [exceptions ENABLED, type reduction ENABLED (globally allowed types)] and run tests + build_minimal_globally_allowed_types: + name: 4. Build Minimal (Globally Allowed Types) + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: # Permissions needed for build-docker-image + contents: read + packages: write + id-token: write # If using OIDC for ACR login + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + - uses: actions/setup-node@v4 + with: + node-version: 20 + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.6 + with: + vcpkg-version: '2025.04.09' + vcpkg-hash: '31a28b58854b7c7b503db99bb2eb41582d9f835b401adf3bd0f680ef329faa4ab4278b987b586a2a6141e2c98f007833266a4e3b60c3164226a3905466a082ce' + cmake-version: '3.31.6' + cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04' + add-cmake-to-path: 'true' + disable-terrapin: 'true' + + - name: Build Full ORT and Prepare Test Files + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.6 + with: + globally_allowed_types: 'bool,float,int8_t,uint8_t' + enable-type-reduction: 'true' + skip-model-tests: 'true' + binary-size-report-name-prefix: "4" + + # Job 5: Build extended minimal onnxruntime and run tests + build_extended_minimal: + name: 5. Build Extended Minimal + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: # Permissions needed for build-docker-image + contents: read + packages: write + id-token: write # If using OIDC for ACR login + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + - uses: actions/setup-node@v4 + with: + node-version: 20 + + - name: Get Docker Image using Action + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + id: build_docker_image_step + with: + dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecpubuildcix64 + push: true + azure-container-registry-name: onnxruntimebuildcache + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Run Build 5 (Update) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: Debug + mode: 'update' + extra_build_flags: >- + --cmake_generator Ninja + --build_shared_lib + --use_binskim_compliant_compile_flags + --minimal_build extended + + - name: Run Build 5 (Build) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: Debug + mode: 'build' + extra_build_flags: >- + --cmake_generator Ninja + --build_shared_lib + --use_binskim_compliant_compile_flags + --minimal_build extended + - name: Run Build 5 (Test) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: Debug + mode: 'test' + extra_build_flags: >- + --cmake_generator Ninja + --build_shared_lib + --use_binskim_compliant_compile_flags + --minimal_build extended + + # Job 6a: Regular build with python and all optional features disabled. + build_regular_no_optional: + name: 6a. Build Regular (No Optional Features) + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: # Permissions needed for build-docker-image + contents: read + packages: write + id-token: write # If using OIDC for ACR login + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: Get Docker Image using Action + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + id: build_docker_image_step + with: + dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecpubuildcix64 + push: true + azure-container-registry-name: onnxruntimebuildcache + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: gen config + shell: bash + run: | + mkdir -p ${{ runner.temp }}/.test_data + touch ${{ runner.temp }}/.test_data/include_no_operators.config + + - name: Run Build 6a (Update) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: MinSizeRel + mode: 'update' + extra_build_flags: >- + --cmake_generator Ninja + --build_wheel + --use_binskim_compliant_compile_flags + --disable_ml_ops + --disable_types sparsetensor float8 optional + --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF + + - name: Run Build 6a (Build) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: MinSizeRel + mode: 'build' + extra_build_flags: >- + --cmake_generator Ninja + --build_wheel + --use_binskim_compliant_compile_flags + --disable_ml_ops + --disable_types sparsetensor float8 optional + --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF + + + - name: Run Build 6a (Test) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: MinSizeRel + mode: 'test' + extra_build_flags: >- + --cmake_generator Ninja + --build_wheel + --use_binskim_compliant_compile_flags + --disable_ml_ops + --disable_types sparsetensor float8 optional + --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF + + # Job 6b: Minimal build with all optional features disabled. + build_minimal_no_optional: + name: 6b. Build Minimal (No Optional Features) + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: # Permissions needed for build-docker-image + contents: read + packages: write + id-token: write # If using OIDC for ACR login + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: gen config + shell: bash + run: | + mkdir -p ${{ runner.temp }}/.test_data + touch ${{ runner.temp }}/.test_data/include_no_operators.config + + - name: Get Docker Image using Action + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + id: build_docker_image_step + with: + dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecpubuildcix64 + push: true + azure-container-registry-name: onnxruntimebuildcache + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Run Build 6b (Update) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: MinSizeRel # From original --config MinSizeRel + mode: 'update' + extra_build_flags: >- + --cmake_generator Ninja + --use_binskim_compliant_compile_flags + --minimal_build + --disable_exceptions + --disable_ml_ops + --skip_tests + --enable_reduced_operator_type_support + --disable_types sparsetensor optional float8 + --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF + + - name: Run Build 6b (Build) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: MinSizeRel # From original --config MinSizeRel + mode: 'build' + extra_build_flags: >- + --cmake_generator Ninja + --use_binskim_compliant_compile_flags + --minimal_build + --disable_exceptions + --disable_ml_ops + --skip_tests + --enable_reduced_operator_type_support + --disable_types sparsetensor optional float8 + --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF + + # Job 6c: Extended minimal build with all optional features disabled. + build_extended_minimal_no_optional: + name: 6c. Build Extended Minimal (No Optional Features) + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: # Permissions needed for build-docker-image + contents: read + packages: write + id-token: write # If using OIDC for ACR login + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: gen config + shell: bash + run: | + mkdir -p ${{ runner.temp }}/.test_data + touch ${{ runner.temp }}/.test_data/include_no_operators.config + + - name: Get Docker Image using Action + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + id: build_docker_image_step + with: + dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecpubuildcix64 + push: true + azure-container-registry-name: onnxruntimebuildcache + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: gen config + shell: bash + run: | + mkdir -p ${{ runner.temp }}/.test_data + touch ${{ runner.temp }}/.test_data/include_no_operators.config + + - name: Run Build 6c (Update) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: MinSizeRel # From original --config MinSizeRel + mode: 'update' + extra_build_flags: >- + --cmake_generator Ninja + --use_binskim_compliant_compile_flags + --minimal_build extended + --disable_exceptions + --disable_ml_ops + --skip_tests + --enable_reduced_operator_type_support + --disable_types sparsetensor optional float8 + --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF + + - name: Run Build 6c (Build) + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 + with: + docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name + }} + build_config: MinSizeRel # From original --config MinSizeRel + mode: 'build' + extra_build_flags: >- + --cmake_generator Ninja + --use_binskim_compliant_compile_flags + --minimal_build extended + --disable_exceptions + --disable_ml_ops + --skip_tests + --enable_reduced_operator_type_support + --disable_types sparsetensor optional float8 + --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF + + # Job 7: Extended minimal build with NNAPI EP for Android(arm64-v8a) and skip tests. + # NOTE: Keeping this as direct docker run due to custom volume mounts needed for Android SDK/NDK + build_extended_minimal_android: + name: 7. Build Extended Minimal (Android NNAPI) + needs: build_full_ort # Depends on Job 1 for test data + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] + permissions: # Permissions needed for build-docker-image + contents: read + packages: write + id-token: write # If using OIDC for ACR login + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + - uses: actions/setup-node@v4 + with: + node-version: 20 + - name: Download Test Data Artifact + uses: actions/download-artifact@v4 + with: + name: test_data + path: ${{ runner.temp }}/.test_data/ + + - name: Get Docker Image using Action + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 + id: build_docker_image_step + with: + dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecpubuildcix64 + push: true + azure-container-registry-name: onnxruntimebuildcache + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Setup Android NDK + uses: ./.github/actions/setup-android-ndk + with: + ndk-version: 28.0.13004108 + # Use default android-sdk-root if not specified + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Run Build 7 (Using docker run) + shell: bash + run: | + # Create the target dir for build output inside the runner's temp dir first + mkdir -p ${{ runner.temp }}/7 + + # Ensure ANDROID_NDK_HOME is available and get its real path + if [ -z "$ANDROID_NDK_HOME" ]; then + echo "ANDROID_NDK_HOME is not set." + exit 1 + fi + NDK_HOME_REALPATH=$(realpath $ANDROID_NDK_HOME) + + # Ensure ANDROID_HOME is available + if [ -z "$ANDROID_HOME" ]; then + echo "ANDROID_HOME is not set. Using default /usr/local/lib/android/sdk" + export ANDROID_HOME=/usr/local/lib/android/sdk + fi + + docker run --rm \ + --volume ${{ env.BUILD_SOURCES_DIRECTORY }}:/onnxruntime_src \ + --volume ${{ runner.temp }}:/build \ + --volume $ANDROID_HOME:/android_home \ + --volume $NDK_HOME_REALPATH:/ndk_home \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ + -e NIGHTLY_BUILD=1 -e ACTIONS_CACHE_URL -e ACTIONS_RUNTIME_TOKEN -e RUNNER_TEMP=/build \ + ${{ steps.build_docker_image_step.outputs.full-image-name }} \ + bash -c "python3 -m pip install -r /onnxruntime_src/tools/ci_build/requirements/pybind/requirements.txt \ + && python3 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build/7 \ + --cmake_generator Ninja \ + --config MinSizeRel \ + --skip_submodule_sync \ + --parallel --use_binskim_compliant_compile_flags \ + --android \ + --android_sdk_path /android_home \ + --android_ndk_path /ndk_home \ + --android_abi=arm64-v8a \ + --android_api=29 \ + --use_nnapi \ + --minimal_build extended \ + --build_shared_lib \ + --disable_ml_ops \ + --disable_exceptions \ + --skip_tests" + working-directory: ${{ env.BUILD_SOURCES_DIRECTORY }} diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 7fb3786d35b93..f8d4a0d4dd218 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -52,7 +52,7 @@ jobs: # --- Build the Docker image needed for testing --- - name: Build Docker Image for Testing - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.5 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -95,7 +95,7 @@ jobs: # So build.py --build_dir build/Release inside the container correctly finds the artifacts. - name: Test ONNX Runtime id: test_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: Release diff --git a/.github/workflows/reusable_linux_build.yml b/.github/workflows/reusable_linux_build.yml index 7ff9260558ebf..27595254800f9 100644 --- a/.github/workflows/reusable_linux_build.yml +++ b/.github/workflows/reusable_linux_build.yml @@ -83,7 +83,7 @@ jobs: python-version: ${{ inputs.python_version }} - name: Build Docker Image (${{ inputs.architecture }} / ${{ inputs.build_config }}) - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.5 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.6 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/${{ inputs.dockerfile_path }} @@ -103,7 +103,7 @@ jobs: # ------------- Update Step (CMake Generation) ------------- - name: Generate Build Files (CMake) (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: update_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} @@ -115,7 +115,7 @@ jobs: # ------------- Build Step (Compilation) ------------- - name: Build ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: build_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} @@ -128,7 +128,7 @@ jobs: - name: Test ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: test_step if: inputs.run_tests == true - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.5 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.6 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} diff --git a/.github/workflows/windows_build_x64_asan.yml b/.github/workflows/windows_build_x64_asan.yml index adf2aed801480..42ecf84369b6f 100644 --- a/.github/workflows/windows_build_x64_asan.yml +++ b/.github/workflows/windows_build_x64_asan.yml @@ -3,13 +3,13 @@ name: windows_x64_asan on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main ] + branches: [main] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -18,33 +18,33 @@ jobs: timeout-minutes: 300 steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: false - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - architecture: x64 - - - name: Locate vcvarsall and Setup Env - uses: ./.github/actions/locate-vcvarsall-and-setup-env # Use the composite action - with: - architecture: x64 - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Build and Test (Combined) - shell: cmd - run: | - @echo off - echo %PATH% - python -m pip install -r "%GITHUB_WORKSPACE%\tools\ci_build/github/windows\python\requirements.txt" - python "%GITHUB_WORKSPACE%\tools\ci_build\build.py" --config Debug --build_dir "%RUNNER_TEMP%\build" --skip_submodule_sync --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_generator "Visual Studio 17 2022" --disable_memleak_checker --enable_address_sanitizer + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env # Use the composite action + with: + architecture: x64 + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Build and Test (Combined) + shell: cmd + run: | + @echo off + echo %PATH% + python -m pip install -r "%GITHUB_WORKSPACE%\tools\ci_build/github/windows\python\requirements.txt" + python "%GITHUB_WORKSPACE%\tools\ci_build\build.py" --config Debug --build_dir "%RUNNER_TEMP%\build" --skip_submodule_sync --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_generator "Visual Studio 17 2022" --disable_memleak_checker --enable_address_sanitizer diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 0687bf0a2529d..826c19f31e7e4 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -12,7 +12,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -90,7 +90,7 @@ jobs: - name: Use Nuget 6.x uses: nuget/setup-nuget@v2 with: - nuget-version: '6.x' + nuget-version: '6.x' - name: NuGet restore run: nuget restore ${{ github.workspace }}\packages.config -ConfigFile ${{ github.workspace }}\NuGet.config -PackagesDirectory ${{ runner.temp }}\build\RelWithDebInfo diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index 57656d951d40a..c526311036dbe 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -12,7 +12,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -76,7 +76,7 @@ jobs: - name: Use Nuget 6.x uses: nuget/setup-nuget@v2 with: - nuget-version: '6.x' + nuget-version: '6.x' - name: NuGet restore run: nuget restore ${{ github.workspace }}\packages.config -ConfigFile ${{ github.workspace }}\NuGet.config -PackagesDirectory ${{ github.workspace }}\RelWithDebInfo diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index 04d252ebcba19..f38fcdae57a35 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -12,7 +12,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index cf1658b390fad..e65d23069ad32 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -12,7 +12,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -46,8 +46,8 @@ jobs: AZCOPY_AUTO_LOGIN_TYPE: MSI AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 - - name: Download TensorRT-10.8.0.43.Windows10.x86_64.cuda-12.8 - run: 'azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.8.0.43.Windows10.x86_64.cuda-12.8" ${{ runner.temp }}' + - name: Download TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8 + run: 'azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" ${{ runner.temp }}' shell: pwsh env: AZCOPY_AUTO_LOGIN_TYPE: MSI @@ -67,18 +67,18 @@ jobs: Write-Host "CUDA Path: $env:RUNNER_TEMP\v12.2\bin" Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\bin" Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\extras\CUPTI\lib64" - Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\TensorRT-10.8.0.43.Windows10.x86_64.cuda-12.8\lib" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib" - name: Generate sln working-directory: ${{ runner.temp }} run: | - python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.8.0.43.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 shell: cmd - name: Build working-directory: ${{ runner.temp }} run: | - python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --build --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.8.0.43.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --build --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 shell: cmd - name: Add build dir to PATH @@ -96,6 +96,6 @@ jobs: working-directory: ${{ runner.temp }} run: | mklink /D /J ${{ github.workspace }}\RelWithDebInfo\models ${{ github.workspace }}\models - python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.8.0.43.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 shell: cmd timeout-minutes: 180 diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 8b3b8a2fcde54..999025f560674 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -12,7 +12,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml b/.github/workflows/windows_x64_debug_build_x64_debug.yml index 04f19ff8664e3..f4c865efe52f1 100644 --- a/.github/workflows/windows_x64_debug_build_x64_debug.yml +++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml @@ -2,13 +2,13 @@ name: windows_x64_debug on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main ] + branches: [main] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -17,117 +17,117 @@ jobs: timeout-minutes: 300 steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: false - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - architecture: x64 - - - name: Locate vcvarsall and Setup Env - uses: ./.github/actions/locate-vcvarsall-and-setup-env # Use the composite action - with: - architecture: x64 - - - name: Install python modules - shell: cmd - run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20.x' - - - name: Setup Java - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '17' - architecture: x64 - - - name: API Documentation Check and generate - shell: cmd - run: | - set ORT_DOXY_SRC=${{ github.workspace }} - set ORT_DOXY_OUT=${{ github.workspace }}\build\Debug\Debug - mkdir %ORT_DOXY_SRC% - mkdir %ORT_DOXY_OUT% - "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg - working-directory: ${{ github.workspace }} - - - name: Use .NET 8.x - uses: actions/setup-dotnet@v4 - with: - dotnet-version: '8.x' - env: - PROCESSOR_ARCHITECTURE: x64 - - - name: Use Nuget 6.x - uses: nuget/setup-nuget@v2 # Use the official NuGet setup action - with: - nuget-version: '6.x' - - - name: NuGet restore - shell: cmd - run: | - nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\Debug -ConfigFile ${{ github.workspace }}\NuGet.config - - - uses: actions/cache@v4 - id: onnx-node-tests-cache - with: - path: ${{ github.workspace }}/js/test/ - key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Build and Test - shell: pwsh - run: | - python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config Debug --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_java --build_nodejs --build_wheel --disable_memleak_checker --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache - if ($LASTEXITCODE -ne 0) { - exit $LASTEXITCODE - } - Remove-Item "${{ github.workspace }}\build\Debug" -Include "*.obj" -Recurse - env: # Set environment variables here, applies to this step only - ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' - DocUpdateNeeded: 'false' # Can be set dynamically based on build output if needed - - - - name: Validate C# native delegates - shell: cmd - run: python tools\ValidateNativeDelegateAttributes.py - working-directory: ${{ github.workspace }}\\csharp - - - name: Install onnxruntime wheel - shell: pwsh - run: | - python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq - Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} - working-directory: "${{ github.workspace }}\\build\\Debug\\Debug" + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env # Use the composite action + with: + architecture: x64 + + - name: Install python modules + shell: cmd + run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - name: Setup Java + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - name: API Documentation Check and generate + shell: cmd + run: | + set ORT_DOXY_SRC=${{ github.workspace }} + set ORT_DOXY_OUT=${{ github.workspace }}\build\Debug\Debug + mkdir %ORT_DOXY_SRC% + mkdir %ORT_DOXY_OUT% + "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg + working-directory: ${{ github.workspace }} + + - name: Use .NET 8.x + uses: actions/setup-dotnet@v4 + with: + dotnet-version: '8.x' + env: + PROCESSOR_ARCHITECTURE: x64 + + - name: Use Nuget 6.x + uses: nuget/setup-nuget@v2 # Use the official NuGet setup action + with: + nuget-version: '6.x' + + - name: NuGet restore + shell: cmd + run: | + nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\Debug -ConfigFile ${{ github.workspace }}\NuGet.config + + - uses: actions/cache@v4 + id: onnx-node-tests-cache + with: + path: ${{ github.workspace }}/js/test/ + key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Build and Test + shell: pwsh + run: | + python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config Debug --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_java --build_nodejs --build_wheel --disable_memleak_checker --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache + if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE + } + Remove-Item "${{ github.workspace }}\build\Debug" -Include "*.obj" -Recurse + env: # Set environment variables here, applies to this step only + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + DocUpdateNeeded: 'false' # Can be set dynamically based on build output if needed + + + - name: Validate C# native delegates + shell: cmd + run: python tools\ValidateNativeDelegateAttributes.py + working-directory: ${{ github.workspace }}\\csharp + + - name: Install onnxruntime wheel + shell: pwsh + run: | + python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq + Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} + working-directory: "${{ github.workspace }}\\build\\Debug\\Debug" # Publish artifacts only on failure and if DocUpdateNeeded is true (example) - - name: Publish OperatorKernels.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' # Use env. for step-level vars - with: - name: OperatorKernels.md - path: ${{ github.workspace }}/docs/OperatorKernels.md - - - name: Publish ContribOperators.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: ContribOperators.md - path: ${{ github.workspace }}/docs/ContribOperators.md + - name: Publish OperatorKernels.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' # Use env. for step-level vars + with: + name: OperatorKernels.md + path: ${{ github.workspace }}/docs/OperatorKernels.md + + - name: Publish ContribOperators.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: ContribOperators.md + path: ${{ github.workspace }}/docs/ContribOperators.md # These variables will persist for the entire job env: diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml index 3ff2ec88e8464..cf4e725d9495e 100644 --- a/.github/workflows/windows_x64_release_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_build_x64_release.yml @@ -2,13 +2,13 @@ name: windows_x64_release on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main ] + branches: [main] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -17,115 +17,115 @@ jobs: timeout-minutes: 300 steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: false - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - architecture: x64 - - - name: Locate vcvarsall and Setup Env - uses: ./.github/actions/locate-vcvarsall-and-setup-env - with: - architecture: x64 - - - name: Install python modules - shell: cmd - run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20.x' - - - name: Setup Java - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '17' - architecture: x64 - - - name: API Documentation Check and generate - shell: cmd - run: | - set ORT_DOXY_SRC=${{ github.workspace }} - set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo - mkdir %ORT_DOXY_SRC% - mkdir %ORT_DOXY_OUT% - "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg - working-directory: ${{ github.workspace }} - - - name: Use .NET 8.x - uses: actions/setup-dotnet@v4 - with: - dotnet-version: '8.x' - env: - PROCESSOR_ARCHITECTURE: x64 - - - name: Use Nuget 6.x - uses: nuget/setup-nuget@v2 - with: - nuget-version: '6.x' - - - name: NuGet restore - shell: cmd - run: | - nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - - uses: actions/cache@v4 - id: onnx-node-tests-cache - with: - path: ${{ github.workspace }}/js/test/ - key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Build and Test - shell: pwsh - run: | - python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_wheel --build_java --build_nodejs --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache - if ($LASTEXITCODE -ne 0) { - exit $LASTEXITCODE - } - Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse - env: - ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' - DocUpdateNeeded: 'false' - - - name: Validate C# native delegates - shell: cmd - run: python tools\ValidateNativeDelegateAttributes.py - working-directory: ${{ github.workspace }}\\csharp - - - name: Install onnxruntime wheel - shell: pwsh - run: | - python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq - Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} - working-directory: "${{ github.workspace }}\\build\\RelWithDebInfo\\RelWithDebInfo" - - - name: Publish OperatorKernels.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: OperatorKernels.md - path: ${{ github.workspace }}/docs/OperatorKernels.md - - - name: Publish ContribOperators.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: ContribOperators.md - path: ${{ github.workspace }}/docs/ContribOperators.md + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + shell: cmd + run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - name: Setup Java + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - name: API Documentation Check and generate + shell: cmd + run: | + set ORT_DOXY_SRC=${{ github.workspace }} + set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo + mkdir %ORT_DOXY_SRC% + mkdir %ORT_DOXY_OUT% + "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg + working-directory: ${{ github.workspace }} + + - name: Use .NET 8.x + uses: actions/setup-dotnet@v4 + with: + dotnet-version: '8.x' + env: + PROCESSOR_ARCHITECTURE: x64 + + - name: Use Nuget 6.x + uses: nuget/setup-nuget@v2 + with: + nuget-version: '6.x' + + - name: NuGet restore + shell: cmd + run: | + nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config + + - uses: actions/cache@v4 + id: onnx-node-tests-cache + with: + path: ${{ github.workspace }}/js/test/ + key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Build and Test + shell: pwsh + run: | + python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_wheel --build_java --build_nodejs --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache + if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE + } + Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse + env: + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + DocUpdateNeeded: 'false' + + - name: Validate C# native delegates + shell: cmd + run: python tools\ValidateNativeDelegateAttributes.py + working-directory: ${{ github.workspace }}\\csharp + + - name: Install onnxruntime wheel + shell: pwsh + run: | + python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq + Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} + working-directory: "${{ github.workspace }}\\build\\RelWithDebInfo\\RelWithDebInfo" + + - name: Publish OperatorKernels.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: OperatorKernels.md + path: ${{ github.workspace }}/docs/OperatorKernels.md + + - name: Publish ContribOperators.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: ContribOperators.md + path: ${{ github.workspace }}/docs/ContribOperators.md env: OrtPackageId: Microsoft.ML.OnnxRuntime diff --git a/.github/workflows/windows_x64_release_dnnl_build_x64_release.yml b/.github/workflows/windows_x64_release_dnnl_build_x64_release.yml index 915688ead150a..4c74505ad183d 100644 --- a/.github/workflows/windows_x64_release_dnnl_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_dnnl_build_x64_release.yml @@ -2,13 +2,13 @@ name: windows_x64_dnnl_release on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main ] + branches: [main] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -17,114 +17,114 @@ jobs: timeout-minutes: 300 steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: false - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - architecture: x64 - - - name: Locate vcvarsall and Setup Env - uses: ./.github/actions/locate-vcvarsall-and-setup-env - with: - architecture: x64 - - - name: Install python modules - shell: cmd - run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20.x' - - - name: Setup Java - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '17' - architecture: x64 - - - name: API Documentation Check and generate - shell: cmd - run: | - set ORT_DOXY_SRC=${{ github.workspace }} - set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo - mkdir %ORT_DOXY_SRC% - mkdir %ORT_DOXY_OUT% - "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg - working-directory: ${{ github.workspace }} - - - name: Use .NET 8.x - uses: actions/setup-dotnet@v4 - with: - dotnet-version: '8.x' - env: - PROCESSOR_ARCHITECTURE: x64 - - - name: Use Nuget 6.x - uses: nuget/setup-nuget@v2 - with: - nuget-version: '6.x' - - - name: NuGet restore - shell: cmd - run: | - nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - - uses: actions/cache@v4 - id: onnx-node-tests-cache - with: - path: ${{ github.workspace }}/js/test/ - key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Build and Test - shell: pwsh - run: | - python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_wheel --build_java --build_nodejs --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_dnnl - if ($LASTEXITCODE -ne 0) { - exit $LASTEXITCODE - } - Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse - env: - ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' - - - name: Validate C# native delegates - shell: cmd - run: python tools\ValidateNativeDelegateAttributes.py - working-directory: ${{ github.workspace }}\\csharp - - - name: Install onnxruntime wheel - shell: pwsh - run: | - python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq - Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} - working-directory: "${{ github.workspace }}\\build\\RelWithDebInfo\\RelWithDebInfo" - - - name: Publish OperatorKernels.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: OperatorKernels.md - path: ${{ github.workspace }}/docs/OperatorKernels.md - - - name: Publish ContribOperators.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: ContribOperators.md - path: ${{ github.workspace }}/docs/ContribOperators.md + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + shell: cmd + run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - name: Setup Java + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - name: API Documentation Check and generate + shell: cmd + run: | + set ORT_DOXY_SRC=${{ github.workspace }} + set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo + mkdir %ORT_DOXY_SRC% + mkdir %ORT_DOXY_OUT% + "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg + working-directory: ${{ github.workspace }} + + - name: Use .NET 8.x + uses: actions/setup-dotnet@v4 + with: + dotnet-version: '8.x' + env: + PROCESSOR_ARCHITECTURE: x64 + + - name: Use Nuget 6.x + uses: nuget/setup-nuget@v2 + with: + nuget-version: '6.x' + + - name: NuGet restore + shell: cmd + run: | + nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config + + - uses: actions/cache@v4 + id: onnx-node-tests-cache + with: + path: ${{ github.workspace }}/js/test/ + key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Build and Test + shell: pwsh + run: | + python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_wheel --build_java --build_nodejs --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_dnnl + if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE + } + Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse + env: + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + + - name: Validate C# native delegates + shell: cmd + run: python tools\ValidateNativeDelegateAttributes.py + working-directory: ${{ github.workspace }}\\csharp + + - name: Install onnxruntime wheel + shell: pwsh + run: | + python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq + Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} + working-directory: "${{ github.workspace }}\\build\\RelWithDebInfo\\RelWithDebInfo" + + - name: Publish OperatorKernels.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: OperatorKernels.md + path: ${{ github.workspace }}/docs/OperatorKernels.md + + - name: Publish ContribOperators.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: ContribOperators.md + path: ${{ github.workspace }}/docs/ContribOperators.md env: OrtPackageId: Microsoft.ML.OnnxRuntime diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml index 491d06f151cd0..76a6203c4dc76 100644 --- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml +++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml @@ -2,13 +2,13 @@ name: windows_x64_release_ep_generic_interface on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main ] + branches: [main] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -17,100 +17,100 @@ jobs: timeout-minutes: 300 steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: false - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - architecture: x64 - - - name: Locate vcvarsall and Setup Env - uses: ./.github/actions/locate-vcvarsall-and-setup-env - with: - architecture: x64 - - - name: Install python modules - shell: cmd - run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20.x' - - - name: Setup Java - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '17' - architecture: x64 - - - name: API Documentation Check and generate - shell: cmd - run: | - set ORT_DOXY_SRC=${{ github.workspace }} - set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo - mkdir %ORT_DOXY_SRC% - mkdir %ORT_DOXY_OUT% - "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg - working-directory: ${{ github.workspace }} - - - name: Use .NET 8.x - uses: actions/setup-dotnet@v4 - with: - dotnet-version: '8.x' - env: - PROCESSOR_ARCHITECTURE: x64 - - - name: Use Nuget 6.x - uses: nuget/setup-nuget@v2 - with: - nuget-version: '6.x' - - - name: NuGet restore - shell: cmd - run: | - nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Build - shell: pwsh - run: | - python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --update --build --enable_generic_interface --use_vcpkg --use_vcpkg_ms_internal_asset_cache - if ($LASTEXITCODE -ne 0) { - exit $LASTEXITCODE - } - Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse - env: - ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' - - - name: Validate C# native delegates - shell: cmd - run: python tools\ValidateNativeDelegateAttributes.py - working-directory: ${{ github.workspace }}\\csharp - - name: Publish OperatorKernels.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: OperatorKernels.md - path: ${{ github.workspace }}/docs/OperatorKernels.md - - - name: Publish ContribOperators.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: ContribOperators.md - path: ${{ github.workspace }}/docs/ContribOperators.md + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + shell: cmd + run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - name: Setup Java + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - name: API Documentation Check and generate + shell: cmd + run: | + set ORT_DOXY_SRC=${{ github.workspace }} + set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo + mkdir %ORT_DOXY_SRC% + mkdir %ORT_DOXY_OUT% + "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg + working-directory: ${{ github.workspace }} + + - name: Use .NET 8.x + uses: actions/setup-dotnet@v4 + with: + dotnet-version: '8.x' + env: + PROCESSOR_ARCHITECTURE: x64 + + - name: Use Nuget 6.x + uses: nuget/setup-nuget@v2 + with: + nuget-version: '6.x' + + - name: NuGet restore + shell: cmd + run: | + nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Build and test + shell: pwsh + run: | + python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --update --build --test --enable_generic_interface --use_vcpkg --use_vcpkg_ms_internal_asset_cache + if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE + } + Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse + env: + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + + - name: Validate C# native delegates + shell: cmd + run: python tools\ValidateNativeDelegateAttributes.py + working-directory: ${{ github.workspace }}\\csharp + - name: Publish OperatorKernels.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: OperatorKernels.md + path: ${{ github.workspace }}/docs/OperatorKernels.md + + - name: Publish ContribOperators.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: ContribOperators.md + path: ${{ github.workspace }}/docs/ContribOperators.md env: OrtPackageId: Microsoft.ML.OnnxRuntime diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml index 67f25b655bf04..f95706764d345 100644 --- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml @@ -2,13 +2,13 @@ name: windows_x64_release_vitisai on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main ] + branches: [main] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -17,97 +17,97 @@ jobs: timeout-minutes: 300 steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: false - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - architecture: x64 - - - name: Locate vcvarsall and Setup Env - uses: ./.github/actions/locate-vcvarsall-and-setup-env - with: - architecture: x64 - - - name: Install python modules - shell: cmd - run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20.x' - - - name: Setup Java - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '17' - architecture: x64 - - - name: API Documentation Check and generate - shell: cmd - run: | - set ORT_DOXY_SRC=${{ github.workspace }} - set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo - mkdir %ORT_DOXY_SRC% - mkdir %ORT_DOXY_OUT% - "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg - working-directory: ${{ github.workspace }} - - - name: Use .NET 8.x - uses: actions/setup-dotnet@v4 - with: - dotnet-version: '8.x' - env: - PROCESSOR_ARCHITECTURE: x64 - - - name: Use Nuget 6.x - uses: nuget/setup-nuget@v2 - with: - nuget-version: '6.x' - - - name: NuGet restore - shell: cmd - run: | - nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Build - shell: pwsh - run: | - python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --update --build --build_wheel --use_vitisai --use_vcpkg --use_vcpkg_ms_internal_asset_cache - if ($LASTEXITCODE -ne 0) { - exit $LASTEXITCODE - } - Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse - env: - ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' - - - name: Validate C# native delegates - shell: cmd - run: python tools\ValidateNativeDelegateAttributes.py - working-directory: ${{ github.workspace }}\\csharp - - - name: Install onnxruntime wheel - shell: pwsh - run: | - python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq - if ($LASTEXITCODE -ne 0) { - exit $LASTEXITCODE - } - Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} - working-directory: "${{ github.workspace }}\\build\\RelWithDebInfo\\RelWithDebInfo" + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + shell: cmd + run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - name: Setup Java + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - name: API Documentation Check and generate + shell: cmd + run: | + set ORT_DOXY_SRC=${{ github.workspace }} + set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo + mkdir %ORT_DOXY_SRC% + mkdir %ORT_DOXY_OUT% + "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg + working-directory: ${{ github.workspace }} + + - name: Use .NET 8.x + uses: actions/setup-dotnet@v4 + with: + dotnet-version: '8.x' + env: + PROCESSOR_ARCHITECTURE: x64 + + - name: Use Nuget 6.x + uses: nuget/setup-nuget@v2 + with: + nuget-version: '6.x' + + - name: NuGet restore + shell: cmd + run: | + nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Build + shell: pwsh + run: | + python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --update --build --build_wheel --use_vitisai --use_vcpkg --use_vcpkg_ms_internal_asset_cache + if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE + } + Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse + env: + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + + - name: Validate C# native delegates + shell: cmd + run: python tools\ValidateNativeDelegateAttributes.py + working-directory: ${{ github.workspace }}\\csharp + + - name: Install onnxruntime wheel + shell: pwsh + run: | + python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq + if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE + } + Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} + working-directory: "${{ github.workspace }}\\build\\RelWithDebInfo\\RelWithDebInfo" env: OrtPackageId: Microsoft.ML.OnnxRuntime OnnxRuntimeBuildDirectory: ${{ github.workspace }}\build diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml index 354185cb228a0..e4ee10b691984 100644 --- a/.github/workflows/windows_x64_release_xnnpack.yml +++ b/.github/workflows/windows_x64_release_xnnpack.yml @@ -2,13 +2,13 @@ name: windows_x64_release_xnnpack on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main ] + branches: [main] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -17,102 +17,102 @@ jobs: timeout-minutes: 300 steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: false - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - architecture: x64 - - - name: Locate vcvarsall and Setup Env - uses: ./.github/actions/locate-vcvarsall-and-setup-env - with: - architecture: x64 - - - name: Install python modules - shell: cmd - run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20.x' - - - name: Setup Java - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '17' - architecture: x64 - - - name: API Documentation Check and generate - shell: cmd - run: | - set ORT_DOXY_SRC=${{ github.workspace }} - set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo - mkdir %ORT_DOXY_SRC% - mkdir %ORT_DOXY_OUT% - "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg - working-directory: ${{ github.workspace }} - - - name: Use .NET 8.x - uses: actions/setup-dotnet@v4 - with: - dotnet-version: '8.x' - env: - PROCESSOR_ARCHITECTURE: x64 - - - name: Use Nuget 6.x - uses: nuget/setup-nuget@v2 - with: - nuget-version: '6.x' - - - name: NuGet restore - shell: cmd - run: | - nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Build and Test - shell: pwsh - run: | - python.exe "${{ github.workspace }}\tools\ci_build\build.py" --use_xnnpack --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --disable_rtti --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache - if ($LASTEXITCODE -ne 0) { - exit $LASTEXITCODE - } - Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse - env: - ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' - DocUpdateNeeded: 'false' - - - name: Validate C# native delegates - shell: cmd - run: python tools\ValidateNativeDelegateAttributes.py - working-directory: ${{ github.workspace }}\\csharp - - - name: Publish OperatorKernels.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: OperatorKernels.md - path: ${{ github.workspace }}/docs/OperatorKernels.md - - - name: Publish ContribOperators.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: ContribOperators.md - path: ${{ github.workspace }}/docs/ContribOperators.md + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + shell: cmd + run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - name: Setup Java + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - name: API Documentation Check and generate + shell: cmd + run: | + set ORT_DOXY_SRC=${{ github.workspace }} + set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo + mkdir %ORT_DOXY_SRC% + mkdir %ORT_DOXY_OUT% + "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg + working-directory: ${{ github.workspace }} + + - name: Use .NET 8.x + uses: actions/setup-dotnet@v4 + with: + dotnet-version: '8.x' + env: + PROCESSOR_ARCHITECTURE: x64 + + - name: Use Nuget 6.x + uses: nuget/setup-nuget@v2 + with: + nuget-version: '6.x' + + - name: NuGet restore + shell: cmd + run: | + nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Build and Test + shell: pwsh + run: | + python.exe "${{ github.workspace }}\tools\ci_build\build.py" --use_xnnpack --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --disable_rtti --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache + if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE + } + Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse + env: + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + DocUpdateNeeded: 'false' + + - name: Validate C# native delegates + shell: cmd + run: python tools\ValidateNativeDelegateAttributes.py + working-directory: ${{ github.workspace }}\\csharp + + - name: Publish OperatorKernels.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: OperatorKernels.md + path: ${{ github.workspace }}/docs/OperatorKernels.md + + - name: Publish ContribOperators.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: ContribOperators.md + path: ${{ github.workspace }}/docs/ContribOperators.md env: OrtPackageId: Microsoft.ML.OnnxRuntime diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index d4b1bf4cfeee0..507eacf21cc5a 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -2,13 +2,13 @@ name: Windows CPU CI Pipeline on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main ] + branches: [main] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -17,110 +17,110 @@ jobs: timeout-minutes: 300 steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: false - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.12' - architecture: x86 # x86 Python - - - name: Locate vcvarsall and Setup Env - uses: ./.github/actions/locate-vcvarsall-and-setup-env - with: - architecture: x86 # x86 architecture for vcvarsall - - - name: Install python modules - shell: cmd - run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: '20.x' - architecture: x86 #Add architecture - - - name: Setup Java - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '17' - architecture: x86 # x86 Java - - - name: API Documentation Check and generate - shell: cmd - run: | - set ORT_DOXY_SRC=${{ github.workspace }} - set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo - mkdir %ORT_DOXY_SRC% - mkdir %ORT_DOXY_OUT% - "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg - working-directory: ${{ github.workspace }} - - - name: Use .NET 8.x - uses: actions/setup-dotnet@v4 - with: - dotnet-version: '8.x' - env: - PROCESSOR_ARCHITECTURE: x86 # x86 .NET - - - name: Use Nuget 6.x - uses: nuget/setup-nuget@v2 - with: - nuget-version: '6.x' - - - name: NuGet restore - shell: cmd - run: | - nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Build and Test - shell: pwsh - run: | - python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_wheel --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache - if ($LASTEXITCODE -ne 0) { - exit $LASTEXITCODE - } - Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse - env: - ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' - DocUpdateNeeded: 'false' - - - name: Validate C# native delegates - shell: cmd - run: python tools\ValidateNativeDelegateAttributes.py - working-directory: ${{ github.workspace }}\\csharp - - - name: Install onnxruntime wheel - shell: pwsh - run: | - python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq - Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} - working-directory: "${{ github.workspace }}\\build\\RelWithDebInfo\\RelWithDebInfo" - - - name: Publish OperatorKernels.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: OperatorKernels.md - path: ${{ github.workspace }}/docs/OperatorKernels.md - - - name: Publish ContribOperators.md (Conditional) - uses: actions/upload-artifact@v4 - if: failure() && env.DocUpdateNeeded == 'true' - with: - name: ContribOperators.md - path: ${{ github.workspace }}/docs/ContribOperators.md + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x86 # x86 Python + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x86 # x86 architecture for vcvarsall + + - name: Install python modules + shell: cmd + run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: '20.x' + architecture: x86 #Add architecture + + - name: Setup Java + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + architecture: x86 # x86 Java + + - name: API Documentation Check and generate + shell: cmd + run: | + set ORT_DOXY_SRC=${{ github.workspace }} + set ORT_DOXY_OUT=${{ github.workspace }}\build\RelWithDebInfo\RelWithDebInfo + mkdir %ORT_DOXY_SRC% + mkdir %ORT_DOXY_OUT% + "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg + working-directory: ${{ github.workspace }} + + - name: Use .NET 8.x + uses: actions/setup-dotnet@v4 + with: + dotnet-version: '8.x' + env: + PROCESSOR_ARCHITECTURE: x86 # x86 .NET + + - name: Use Nuget 6.x + uses: nuget/setup-nuget@v2 + with: + nuget-version: '6.x' + + - name: NuGet restore + shell: cmd + run: | + nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Build and Test + shell: pwsh + run: | + python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_wheel --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache + if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE + } + Remove-Item "${{ github.workspace }}\build\RelWithDebInfo" -Include "*.obj" -Recurse + env: + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + DocUpdateNeeded: 'false' + + - name: Validate C# native delegates + shell: cmd + run: python tools\ValidateNativeDelegateAttributes.py + working-directory: ${{ github.workspace }}\\csharp + + - name: Install onnxruntime wheel + shell: pwsh + run: | + python -m pip uninstall -y onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml -qq + Get-ChildItem -Path dist/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname} + working-directory: "${{ github.workspace }}\\build\\RelWithDebInfo\\RelWithDebInfo" + + - name: Publish OperatorKernels.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: OperatorKernels.md + path: ${{ github.workspace }}/docs/OperatorKernels.md + + - name: Publish ContribOperators.md (Conditional) + uses: actions/upload-artifact@v4 + if: failure() && env.DocUpdateNeeded == 'true' + with: + name: ContribOperators.md + path: ${{ github.workspace }}/docs/ContribOperators.md env: OrtPackageId: Microsoft.ML.OnnxRuntime diff --git a/.lintrunner.toml b/.lintrunner.toml index 74744277fa1e3..2bb6048ae4bea 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -91,22 +91,6 @@ init_command = [ ] is_formatter = true -[[linter]] -code = 'RUSTFMT' -include_patterns = ['**/*.rs'] -command = [ - 'python', - '-m', - 'lintrunner_adapters', - 'run', - 'rustfmt_linter', - '--binary=rustfmt', - '--config-path=rust/rustfmt.toml', - '--', - '@{{PATHSFILE}}' -] -is_formatter = true - [[linter]] code = 'CLANGFORMAT' include_patterns = [ diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index cb54bd02d5500..0204ce1423bbf 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -65,6 +65,7 @@ option(onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE "Use a custom SDL Ru option(onnxruntime_ENABLE_PYTHON "Enable python bindings" OFF) # Enable it may cause LNK1169 error option(onnxruntime_ENABLE_MEMLEAK_CHECKER "Experimental: Enable memory leak checker in Windows debug build" OFF) +option(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER "Experimental: Enable ConvSymKernelAvx2 assembly saturation checker in build" OFF) option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) # Enable ONNX Runtime CUDA EP's internal unit tests that directly access the EP's internal functions instead of through # OpKernels. When the option is ON, we will have two copies of GTest library in the same process. It is not a typical @@ -856,7 +857,6 @@ endif() set(ONNXRUNTIME_PROVIDER_NAMES cpu) set(ORT_PROVIDER_FLAGS) -set(ORT_EXTRA_INTERFACE_FLAGS) if (onnxruntime_USE_CUDA) enable_language(CUDA) @@ -912,7 +912,7 @@ if (onnxruntime_USE_CUDA) endif() if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA)) - list(APPEND ORT_EXTRA_INTERFACE_FLAGS -DUSE_CUDA=1) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_CUDA_PROVIDER_INTERFACE=1) endif() if (onnxruntime_USE_VITISAI) @@ -921,7 +921,7 @@ if (onnxruntime_USE_VITISAI) endif() if (onnxruntime_USE_VITISAI_INTERFACE AND (NOT onnxruntime_USE_VITISAI)) - list(APPEND ORT_EXTRA_INTERFACE_FLAGS -DUSE_VITISAI=1) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI_PROVIDER_INTERFACE=1) endif() if (onnxruntime_USE_DNNL) @@ -935,7 +935,7 @@ if (onnxruntime_USE_OPENVINO) endif() if (onnxruntime_USE_OPENVINO_INTERFACE AND (NOT onnxruntime_USE_OPENVINO)) - list(APPEND ORT_EXTRA_INTERFACE_FLAGS -DUSE_OPENVINO=1) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_OPENVINO_PROVIDER_INTERFACE=1) endif() if (onnxruntime_USE_TENSORRT) @@ -945,7 +945,7 @@ if (onnxruntime_USE_TENSORRT) endif() if (onnxruntime_USE_TENSORRT_INTERFACE AND (NOT onnxruntime_USE_TENSORRT)) - list(APPEND ORT_INTERFACE_FLAGS -DUSE_TENSORRT=1) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_TENSORRT_PROVIDER_INTERFACE=1) endif() if (onnxruntime_USE_NV) @@ -954,7 +954,7 @@ if (onnxruntime_USE_NV) endif() if (onnxruntime_USE_NV_INTERFACE AND (NOT onnxruntime_USE_NV)) - list(APPEND ORT_INTERFACE_FLAGS -DUSE_NV=1) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_NV_PROVIDER_INTERFACE=1) endif() if (onnxruntime_USE_RKNPU) @@ -978,7 +978,7 @@ if (onnxruntime_USE_QNN OR onnxruntime_USE_QNN_INTERFACE) if(onnxruntime_USE_QNN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_QNN=1) else() - list(APPEND ORT_EXTRA_INTERFACE_FLAGS -DUSE_QNN=1) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_QNN_PROVIDER_INTERFACE=1) endif() list(APPEND ONNXRUNTIME_PROVIDER_NAMES qnn) @@ -1242,12 +1242,6 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE ${ORT_FLAG}) endforeach() - if("${target_name}" STREQUAL "onnxruntime") - foreach(ORT_EXTRA_FLAG ${ORT_EXTRA_INTERFACE_FLAGS}) - target_compile_definitions(${target_name} PRIVATE ${ORT_EXTRA_FLAG}) - endforeach() - endif() - if (HAS_DEPRECATED_COPY) #too many such errors in eigen target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options -Wno-deprecated-copy>" "$<$:-Wno-deprecated-copy>") @@ -1540,11 +1534,6 @@ if (onnxruntime_USE_OPENVINO) endif() -if (onnxruntime_USE_OPENVINO_INTERFACE AND (NOT onnxruntime_USE_OPENVINO)) - add_definitions(-DUSE_OPENVINO=1) - add_definitions(-DOPENVINO_CONFIG_NPU=1) -endif() - if (onnxruntime_USE_VITISAI) set(CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH};${CMAKE_CURRENT_LIST_DIR}") endif() @@ -1745,7 +1734,7 @@ if(VERSION_MAJOR_PART STREQUAL "0" AND VERSION_MINOR_PART STREQUAL "0" AND VERSI list(GET ORT_VERSION_STRING_LIST 0 VERSION_MAJOR_PART) list(GET ORT_VERSION_STRING_LIST 1 VERSION_MINOR_PART) list(GET ORT_VERSION_STRING_LIST 2 VERSION_BUILD_PART) - set(VERSION_STRING ORT_VERSION) + set(VERSION_STRING ${ORT_VERSION}) endif() diff --git a/cmake/deps.txt b/cmake/deps.txt index 71218fd049afb..a10bede254007 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -35,7 +35,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.17.0.zip;13a60ac5217c104139ce0fd024f48628e7bcf5bc +onnx;https://github.com/onnx/onnx/archive/7fc2b81a275223f5b02a522d9d2649837542a7be.zip;555338a12903941bb45f57540476244f9ffee17b # Use the latest commit of 10.9-GA onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/d5dce67db7c2e64b07e055571f5ec06f7f254de2.zip;01114d3b67650857281fa50faa2e412130a63b69 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa diff --git a/cmake/external/onnx b/cmake/external/onnx index b8baa84466864..7fc2b81a27522 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit b8baa8446686496da4cc8fda09f2b6fe65c2a02c +Subproject commit 7fc2b81a275223f5b02a522d9d2649837542a7be diff --git a/cmake/external/onnx_minimal.cmake b/cmake/external/onnx_minimal.cmake deleted file mode 100644 index 65ff3fb148b11..0000000000000 --- a/cmake/external/onnx_minimal.cmake +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -# -# Setup onnx and onnx_protobuf for a build with onnxruntime_MINIMAL_BUILD enabled. -# We exclude everything but the essentials from the onnx library. -# - -if(NOT onnxruntime_MINIMAL_BUILD) - message(FATAL_ERROR "This file should only be included in a minimal build") -endif() - -#TODO: if protobuf is a shared lib and onnxruntime_USE_FULL_PROTOBUF is ON, then onnx_proto should be built as a shared lib instead of a static lib. Otherwise any code outside onnxruntime.dll can't use onnx protobuf definitions if they share the protobuf.dll with onnxruntime. For example, if protobuf is a shared lib and onnx_proto is a static lib then onnxruntime_perf_test won't work. - - - -FetchContent_Populate(onnx) -set(ONNX_SOURCE_ROOT ${onnx_SOURCE_DIR}) - - - -add_library(onnx_proto ${ONNX_SOURCE_ROOT}/onnx/onnx-ml.proto ${ONNX_SOURCE_ROOT}/onnx/onnx-operators-ml.proto ${ONNX_SOURCE_ROOT}/onnx/onnx-data.proto) - -target_include_directories(onnx_proto PUBLIC $ "${CMAKE_CURRENT_BINARY_DIR}") -target_compile_definitions(onnx_proto PUBLIC $) - -set(_src_prefix "onnx/") -onnxruntime_protobuf_generate(NO_SRC_INCLUDES GEN_SRC_PREFIX ${_src_prefix} IMPORT_DIRS ${ONNX_SOURCE_ROOT} TARGET onnx_proto) - -# For reference, this would be the full ONNX source include. We only need data_type_utils in this build. -# file(GLOB_RECURSE onnx_src CONFIGURE_DEPENDS -# "${ONNX_SOURCE_ROOT}/onnx/*.h" -# "${ONNX_SOURCE_ROOT}/onnx/*.cc" -# ) -# file(GLOB_RECURSE onnx_exclude_src CONFIGURE_DEPENDS -# "${ONNX_SOURCE_ROOT}/onnx/py_utils.h" -# "${ONNX_SOURCE_ROOT}/onnx/proto_utils.h" -# "${ONNX_SOURCE_ROOT}/onnx/backend/test/cpp/*" -# "${ONNX_SOURCE_ROOT}/onnx/test/*" -# "${ONNX_SOURCE_ROOT}/onnx/cpp2py_export.cc" -# ) -# list(REMOVE_ITEM onnx_src ${onnx_exclude_src}) -set(onnx_src - "${ONNX_SOURCE_ROOT}/onnx/common/common.h" - "${ONNX_SOURCE_ROOT}/onnx/defs/data_type_utils.h" - "${ONNX_SOURCE_ROOT}/onnx/defs/data_type_utils.cc" -) - -add_library(onnx ${onnx_src}) -add_dependencies(onnx onnx_proto) -target_include_directories(onnx PUBLIC "${ONNX_SOURCE_ROOT}") -target_include_directories(onnx PUBLIC $) -if (onnxruntime_USE_FULL_PROTOBUF) - target_compile_definitions(onnx PUBLIC "ONNX_ML" "ONNX_NAMESPACE=onnx") -else() - target_compile_definitions(onnx PUBLIC "ONNX_ML" "ONNX_NAMESPACE=onnx" "ONNX_USE_LITE_PROTO") -endif() - -if (WIN32) - target_compile_options(onnx PRIVATE - /wd4800 # 'type' : forcing value to bool 'true' or 'false' (performance warning) - /wd4125 # decimal digit terminates octal escape sequence - /wd4100 # 'param' : unreferenced formal parameter - /wd4244 # 'argument' conversion from 'google::protobuf::int64' to 'int', possible loss of data - /wd4996 # 'argument' Using double parameter version instead of single parameter version of SetTotalBytesLimit(). The second parameter is ignored. - ) - if (NOT onnxruntime_DISABLE_EXCEPTIONS) - target_compile_options(onnx PRIVATE - /EHsc # exception handling - C++ may throw, extern "C" will not - ) - endif() - - target_compile_options(onnx_proto PRIVATE - /wd4244 # 'argument' conversion from 'google::protobuf::int64' to 'int', possible loss of data - ) - - set(onnx_static_library_flags - -IGNORE:4221 # LNK4221: This object file does not define any previously undefined public symbols, so it will not be used by any link operation that consumes this library - ) - set_target_properties(onnx PROPERTIES - STATIC_LIBRARY_FLAGS "${onnx_static_library_flags}") -else() - if(HAS_UNUSED_PARAMETER) - target_compile_options(onnx PRIVATE "-Wno-unused-parameter") - endif() - if(HAS_UNUSED_BUT_SET_VARIABLE) - target_compile_options(onnx PRIVATE "-Wno-unused-but-set-variable") - endif() -endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 8decca10937ba..5d46ac9adb7c2 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -46,6 +46,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/rotary_embedding.h ${MLAS_SRC_DIR}/rotary_embedding.cpp ${MLAS_SRC_DIR}/softmax.h + ${MLAS_SRC_DIR}/saturation_check.cpp ) target_sources(onnxruntime_mlas PRIVATE @@ -239,6 +240,10 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm ) + if(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + set_source_files_properties(${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm PROPERTIES COMPILE_FLAGS "-DENABLE_CONVSYMKERNELAVX2_SAT_CHECKER") + endif() + if(MSVC_VERSION GREATER_EQUAL 1933) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm @@ -637,6 +642,7 @@ else() ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp + ${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp @@ -716,6 +722,10 @@ endif() set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") endif() + if(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + set_source_files_properties(${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -DENABLE_CONVSYMKERNELAVX2_SAT_CHECKER") + endif() + if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs}) set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index b688e61f53915..6a7510a5d83bc 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -122,6 +122,7 @@ if (onnxruntime_REDUCED_OPS_BUILD) substitute_op_reduction_srcs(onnxruntime_providers_cuda_src) endif() + if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and # added to the lib onnxruntime_providers_cuda separately. @@ -129,10 +130,30 @@ set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc) list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src}) onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src}) - onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $) + + set(onnxruntime_providers_cuda_all_srcs ${cuda_provider_interface_src}) + if(WIN32) + # Sets the DLL version info on Windows: https://learn.microsoft.com/en-us/windows/win32/menurc/versioninfo-resource + list(APPEND onnxruntime_providers_cuda_all_srcs "${ONNXRUNTIME_ROOT}/core/providers/cuda/onnxruntime_providers_cuda.rc") + endif() + + onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${onnxruntime_providers_cuda_all_srcs} + $) else() - onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${onnxruntime_providers_cuda_src}) + set(onnxruntime_providers_cuda_all_srcs ${onnxruntime_providers_cuda_src}) + if(WIN32) + # Sets the DLL version info on Windows: https://learn.microsoft.com/en-us/windows/win32/menurc/versioninfo-resource + list(APPEND onnxruntime_providers_cuda_all_srcs "${ONNXRUNTIME_ROOT}/core/providers/cuda/onnxruntime_providers_cuda.rc") + endif() + + onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${onnxruntime_providers_cuda_all_srcs}) + endif() + + if(WIN32) + # FILE_NAME preprocessor definition is used in onnxruntime_providers_cuda.rc + target_compile_definitions(onnxruntime_providers_cuda PRIVATE FILE_NAME=\"onnxruntime_providers_cuda.dll\") endif() + # config_cuda_provider_shared_module can be used to config onnxruntime_providers_cuda_obj, onnxruntime_providers_cuda & onnxruntime_providers_cuda_ut. # This function guarantees that all 3 targets have the same configurations. function(config_cuda_provider_shared_module target) diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake index 1bba3a0d503c5..60b3aaf38cd85 100644 --- a/cmake/onnxruntime_providers_qnn.cmake +++ b/cmake/onnxruntime_providers_qnn.cmake @@ -13,6 +13,27 @@ "${ONNXRUNTIME_ROOT}/core/providers/qnn/*.cc" ) + function(extract_qnn_sdk_version_from_yaml QNN_SDK_YAML_FILE QNN_VERSION_OUTPUT) + file(READ "${QNN_SDK_YAML_FILE}" QNN_SDK_YAML_CONTENT) + # Match a line of text like "version: 1.33.2" + string(REGEX MATCH "(^|\n|\r)version: ([0-9]+\\.[0-9]+\\.[0-9]+)" QNN_VERSION_MATCH "${QNN_SDK_YAML_CONTENT}") + if(QNN_VERSION_MATCH) + set(${QNN_VERSION_OUTPUT} "${CMAKE_MATCH_2}" PARENT_SCOPE) + message(STATUS "Extracted QNN SDK version ${CMAKE_MATCH_2} from ${QNN_SDK_YAML_FILE}") + else() + message(WARNING "Failed to extract QNN SDK version from ${QNN_SDK_YAML_FILE}") + endif() + endfunction() + + if(NOT QNN_SDK_VERSION) + if(EXISTS "${onnxruntime_QNN_HOME}/sdk.yaml") + extract_qnn_sdk_version_from_yaml("${onnxruntime_QNN_HOME}/sdk.yaml" QNN_SDK_VERSION) + else() + message(WARNING "Cannot open sdk.yaml to extract QNN SDK version") + endif() + endif() + message(STATUS "QNN SDK version ${QNN_SDK_VERSION}") + if(onnxruntime_BUILD_QNN_EP_STATIC_LIB) # # Build QNN EP as a static library @@ -23,7 +44,7 @@ onnxruntime_add_include_to_target(onnxruntime_providers_qnn onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers::flatbuffers Boost::mp11 - nlohmann_json::nlohmann_json) + nlohmann_json::nlohmann_json) add_dependencies(onnxruntime_providers_qnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_qnn PROPERTIES CXX_STANDARD_REQUIRED ON) set_target_properties(onnxruntime_providers_qnn PROPERTIES FOLDER "ONNXRuntime") @@ -36,6 +57,21 @@ if(NOT MSVC) target_compile_options(onnxruntime_providers_qnn PRIVATE "-Wno-unknown-pragmas") endif() + + set(onnxruntime_providers_qnn_target onnxruntime_providers_qnn) + + if (MSVC OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") + add_custom_command( + TARGET ${onnxruntime_providers_qnn_target} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} $ + ) + endif() + if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") + add_custom_command( + TARGET ${onnxruntime_providers_qnn_target} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf" $ + ) + endif() else() # # Build QNN EP as a shared library @@ -46,13 +82,20 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) set(onnxruntime_providers_qnn_srcs ${onnxruntime_providers_qnn_ep_srcs} - ${onnxruntime_providers_qnn_shared_lib_srcs}) + ${onnxruntime_providers_qnn_shared_lib_srcs}) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_qnn_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_qnn ${onnxruntime_providers_qnn_srcs}) + + set(onnxruntime_providers_qnn_all_srcs ${onnxruntime_providers_qnn_srcs}) + if(WIN32) + # Sets the DLL version info on Windows: https://learn.microsoft.com/en-us/windows/win32/menurc/versioninfo-resource + list(APPEND onnxruntime_providers_qnn_all_srcs "${ONNXRUNTIME_ROOT}/core/providers/qnn/onnxruntime_providers_qnn.rc") + endif() + + onnxruntime_add_shared_library_module(onnxruntime_providers_qnn ${onnxruntime_providers_qnn_all_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_qnn ${ONNXRUNTIME_PROVIDERS_SHARED} ${GSL_TARGET} onnx - onnxruntime_common Boost::mp11 safeint_interface - nlohmann_json::nlohmann_json) + onnxruntime_common Boost::mp11 safeint_interface + nlohmann_json::nlohmann_json) target_link_libraries(onnxruntime_providers_qnn PRIVATE ${ONNXRUNTIME_PROVIDERS_SHARED} ${ABSEIL_LIBS} ${CMAKE_DL_LIBS}) add_dependencies(onnxruntime_providers_qnn onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) target_include_directories(onnxruntime_providers_qnn PRIVATE ${ONNXRUNTIME_ROOT} @@ -60,6 +103,18 @@ ${onnxruntime_QNN_HOME}/include/QNN ${onnxruntime_QNN_HOME}/include) + # Set preprocessor definitions used in onnxruntime_providers_qnn.rc + if(WIN32) + if(NOT QNN_SDK_VERSION) + set(QNN_DLL_FILE_DESCRIPTION "ONNX Runtime QNN Provider") + else() + set(QNN_DLL_FILE_DESCRIPTION "ONNX Runtime QNN Provider (QAIRT ${QNN_SDK_VERSION})") + endif() + + target_compile_definitions(onnxruntime_providers_qnn PRIVATE FILE_DESC=\"${QNN_DLL_FILE_DESCRIPTION}\") + target_compile_definitions(onnxruntime_providers_qnn PRIVATE FILE_NAME=\"onnxruntime_providers_qnn.dll\") + endif() + # Set linker flags for function(s) exported by EP DLL if(UNIX) target_link_options(onnxruntime_providers_qnn PRIVATE @@ -90,4 +145,19 @@ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + + set(onnxruntime_providers_qnn_target onnxruntime_providers_qnn) + + if (MSVC OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") + add_custom_command( + TARGET ${onnxruntime_providers_qnn_target} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} $ + ) + endif() + if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") + add_custom_command( + TARGET ${onnxruntime_providers_qnn_target} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf" $ + ) + endif() endif() diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 1f7700fa7bc36..59c7db9999b43 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -166,7 +166,14 @@ ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tensorrt_cc_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_tensorrt ${onnxruntime_providers_tensorrt_cc_srcs}) + + set(onnxruntime_providers_tensorrt_all_srcs ${onnxruntime_providers_tensorrt_cc_srcs}) + if(WIN32) + # Sets the DLL version info on Windows: https://learn.microsoft.com/en-us/windows/win32/menurc/versioninfo-resource + list(APPEND onnxruntime_providers_tensorrt_all_srcs "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/onnxruntime_providers_tensorrt.rc") + endif() + + onnxruntime_add_shared_library_module(onnxruntime_providers_tensorrt ${onnxruntime_providers_tensorrt_all_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common) target_link_libraries(onnxruntime_providers_tensorrt PRIVATE Eigen3::Eigen onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface Eigen3::Eigen) add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) @@ -183,6 +190,12 @@ set_target_properties(onnxruntime_providers_tensorrt PROPERTIES FOLDER "ONNXRuntime") target_compile_definitions(onnxruntime_providers_tensorrt PRIVATE ONNXIFI_BUILD_LIBRARY=1) target_compile_options(onnxruntime_providers_tensorrt PRIVATE ${DISABLED_WARNINGS_FOR_TRT}) + + if(WIN32) + # FILE_NAME preprocessor definition is used in onnxruntime_providers_tensorrt.rc + target_compile_definitions(onnxruntime_providers_tensorrt PRIVATE FILE_NAME=\"onnxruntime_providers_tensorrt.dll\") + endif() + if (WIN32) target_compile_options(onnxruntime_providers_tensorrt INTERFACE /wd4456) endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index c0e31990552ea..b31fdd4ea1ee1 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1075,21 +1075,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ) endif() - if (onnxruntime_USE_QNN) - if (MSVC OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") - add_custom_command( - TARGET ${test_data_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} $ - ) - endif() - if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") - add_custom_command( - TARGET ${test_data_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf" $ - ) - endif() - endif() - if (onnxruntime_USE_DNNL) if(onnxruntime_DNNL_GPU_RUNTIME STREQUAL "ocl" AND onnxruntime_DNNL_OPENCL_ROOT STREQUAL "") message(FATAL_ERROR "--dnnl_opencl_root required") @@ -1291,7 +1276,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (CMAKE_SYSTEM_NAME MATCHES "AIX") list(APPEND onnxruntime_perf_test_libs onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 gtest absl_failure_signal_handler absl_examine_stack absl_flags_parse absl_flags_usage absl_flags_usage_internal) - endif() + endif() target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads) if(WIN32) target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32) @@ -1301,7 +1286,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") - endif() +endif() if(onnxruntime_USE_QNN) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index c0d9794df160c..c782db4b6d64d 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -1,16 +1,16 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index d15d97ed..bdacac99 100644 +index 6fe5c96e..087a7780 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -27,6 +27,7 @@ option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." OFF) +@@ -40,6 +40,7 @@ option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." OFF) option(ONNX_DISABLE_EXCEPTIONS "Disable exception handling." OFF) - option(ONNX_DISABLE_STATIC_REGISTRATION "Disable static registration for onnx operator schemas." OFF) + option(ONNX_DISABLE_STATIC_REGISTRATION "Disable static registration for ONNX operator schemas." OFF) option(ONNX_USE_UNITY_BUILD "Enable Unity (Jumbo) build for" OFF) +option(ONNX_MINIMAL_BUILD "Build only essential ONNX components" OFF) - - if(NOT DEFINED ONNX_ML) - if(DEFINED ENV{ONNX_ML}) -@@ -457,14 +458,28 @@ relative_protobuf_generate_cpp(gen_onnx_data_proto + if(WIN32) + option(ONNX_USE_MSVC_STATIC_RUNTIME "Build with MSVC static runtime" OFF) + endif() +@@ -461,14 +462,28 @@ relative_protobuf_generate_cpp(gen_onnx_data_proto list(APPEND ONNX_PROTO_SRCS ${__tmp_srcs}) list(APPEND ONNX_PROTO_HDRS ${__tmp_hdrs}) @@ -47,7 +47,7 @@ index d15d97ed..bdacac99 100644 add_library(onnx_proto ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}) add_dependencies(onnx_proto gen_onnx_operators_proto gen_onnx_data_proto) -@@ -496,6 +511,7 @@ if (MSVC) +@@ -492,6 +507,7 @@ if(MSVC) endif() else() # On non-Windows, hide all symbols we don't need @@ -55,10 +55,10 @@ index d15d97ed..bdacac99 100644 set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)") set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden) set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1) -@@ -631,20 +647,9 @@ endif() - if(MSVC) - target_compile_options(onnx_proto +@@ -620,21 +636,11 @@ if(MSVC) PRIVATE /MP + /wd4146 # unary minus operator applied to unsigned type, + # result still unsigned - /wd4244 #'argument': conversion from 'google:: - #protobuf::uint64' to 'int', possible - # loss of data @@ -67,42 +67,93 @@ index d15d97ed..bdacac99 100644 ${EXTRA_FLAGS}) target_compile_options(onnx PRIVATE /MP + /wd4146 # unary minus operator applied to unsigned type, + # result still unsigned - /wd4244 # 'argument': conversion from 'google:: - # protobuf::uint64' to 'int', possible - # loss of data - /wd4267 # Conversion from 'size_t' to 'int', - # possible loss of data -- /wd4996 # The second parameter is ignored. ${EXTRA_FLAGS}) - if(ONNX_USE_PROTOBUF_SHARED_LIBS) - target_compile_options(onnx_proto -diff --git a/onnx/common/file_utils.h b/onnx/common/file_utils.h -index b847798e..a6c31904 100644 ---- a/onnx/common/file_utils.h -+++ b/onnx/common/file_utils.h -@@ -6,7 +6,6 @@ + add_msvc_runtime_flag(onnx_proto) + add_msvc_runtime_flag(onnx) +diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc +index 64366270..4aed9027 100644 +--- a/onnx/defs/nn/defs.cc ++++ b/onnx/defs/nn/defs.cc +@@ -36,7 +36,7 @@ static const char* conv_transpose_auto_pad_doc = + "on whether it is even or odd). In case the padding is an odd number, the extra " + "padding is added at the end for SAME_UPPER and at the beginning for SAME_LOWER."; + +-static void convPoolShapeInference( ++void convPoolShapeInference( + InferenceContext& ctx, + bool use_dilation, + bool require_kernel_shape, +@@ -1102,7 +1102,7 @@ ONNX_OPERATOR_SET_SCHEMA( + convPoolShapeInference(ctx, true, false, 0, 1); + })); + +-static void convTransposeShapeInference(InferenceContext& ctx) { ++void convTransposeShapeInference(InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); - #pragma once + // we need at least two inputs to have a shape for this inference. +@@ -1462,7 +1462,7 @@ ONNX_OPERATOR_SET_SCHEMA( + })); --#include - #include - #include + // For GlobalPool operations. +-static void globalPoolTypeShapeInference(InferenceContext& ctx) { ++void globalPoolTypeShapeInference(InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); -@@ -17,8 +16,7 @@ namespace ONNX_NAMESPACE { + // needs at least one input with shape. +diff --git a/onnx/defs/nn/old.cc b/onnx/defs/nn/old.cc +index d8ca9a46..1eda4c70 100644 +--- a/onnx/defs/nn/old.cc ++++ b/onnx/defs/nn/old.cc +@@ -4023,7 +4023,6 @@ ONNX_OPERATOR_SET_SCHEMA( + GroupNormalization, + 18, + OpSchema() +- .Deprecate() + .SetDoc(GroupNormalization_ver18_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, 1e-5f) + .Attr( +diff --git a/onnx/defs/rnn/defs.cc b/onnx/defs/rnn/defs.cc +index c0ed3a39..6c8e2909 100644 +--- a/onnx/defs/rnn/defs.cc ++++ b/onnx/defs/rnn/defs.cc +@@ -5,7 +5,7 @@ + #include "onnx/defs/schema.h" - template - void LoadProtoFromPath(const std::string proto_path, T& proto) { -- std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path); -- std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary); -+ std::fstream proto_stream(proto_path, std::ios::in | std::ios::binary); - if (!proto_stream.good()) { - fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. "); - } + namespace ONNX_NAMESPACE { +-static void RNNShapeInference(InferenceContext& ctx) { ++void RNNShapeInference(InferenceContext& ctx) { + TensorShapeProto::Dimension num_directions, seq_length, batch_size, hidden_size; + + auto direction = getAttribute(ctx, "direction", "forward"); +diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h +index 42318d82..a33cf342 100644 +--- a/onnx/defs/schema.h ++++ b/onnx/defs/schema.h +@@ -980,10 +980,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { + class OpSchemaRegisterOnce final { + public: + // Export to cpp custom register macro +- explicit OpSchemaRegisterOnce( +- OpSchema op_schema, +- int opset_version_to_load = 0, +- bool fail_duplicate_schema = true) { ++ OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); + } + static void diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h -index 0aab3e26..398ac2d6 100644 +index 0aab3e26..27f32195 100644 --- a/onnx/onnx_pb.h +++ b/onnx/onnx_pb.h -@@ -47,10 +47,28 @@ +@@ -47,10 +47,30 @@ #define ONNX_API ONNX_IMPORT #endif @@ -119,6 +170,7 @@ index 0aab3e26..398ac2d6 100644 +#endif // defined(__has_warning) + +#endif // defined(__GNUC__) ++ + #ifdef ONNX_ML #include "onnx/onnx-ml.pb.h" @@ -129,5 +181,6 @@ index 0aab3e26..398ac2d6 100644 +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif ++ + #endif // ! ONNX_ONNX_PB_H diff --git a/cmake/vcpkg-configuration.json b/cmake/vcpkg-configuration.json index 8d819f1c98b1f..54696dc9f2c82 100644 --- a/cmake/vcpkg-configuration.json +++ b/cmake/vcpkg-configuration.json @@ -2,7 +2,7 @@ "default-registry": { "kind": "git", "repository": "https://github.com/Microsoft/vcpkg", - "baseline": "a29711cc86340a43c054cd37b8bd2871332a01e9" + "baseline": "ce613c41372b23b1f51333815feb3edd87ef8a8b" }, "overlay-ports": [ "./vcpkg-ports" diff --git a/cmake/vcpkg-ports/onnx/binskim.patch b/cmake/vcpkg-ports/onnx/binskim.patch index c0d9794df160c..c782db4b6d64d 100644 --- a/cmake/vcpkg-ports/onnx/binskim.patch +++ b/cmake/vcpkg-ports/onnx/binskim.patch @@ -1,16 +1,16 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index d15d97ed..bdacac99 100644 +index 6fe5c96e..087a7780 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -27,6 +27,7 @@ option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." OFF) +@@ -40,6 +40,7 @@ option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." OFF) option(ONNX_DISABLE_EXCEPTIONS "Disable exception handling." OFF) - option(ONNX_DISABLE_STATIC_REGISTRATION "Disable static registration for onnx operator schemas." OFF) + option(ONNX_DISABLE_STATIC_REGISTRATION "Disable static registration for ONNX operator schemas." OFF) option(ONNX_USE_UNITY_BUILD "Enable Unity (Jumbo) build for" OFF) +option(ONNX_MINIMAL_BUILD "Build only essential ONNX components" OFF) - - if(NOT DEFINED ONNX_ML) - if(DEFINED ENV{ONNX_ML}) -@@ -457,14 +458,28 @@ relative_protobuf_generate_cpp(gen_onnx_data_proto + if(WIN32) + option(ONNX_USE_MSVC_STATIC_RUNTIME "Build with MSVC static runtime" OFF) + endif() +@@ -461,14 +462,28 @@ relative_protobuf_generate_cpp(gen_onnx_data_proto list(APPEND ONNX_PROTO_SRCS ${__tmp_srcs}) list(APPEND ONNX_PROTO_HDRS ${__tmp_hdrs}) @@ -47,7 +47,7 @@ index d15d97ed..bdacac99 100644 add_library(onnx_proto ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}) add_dependencies(onnx_proto gen_onnx_operators_proto gen_onnx_data_proto) -@@ -496,6 +511,7 @@ if (MSVC) +@@ -492,6 +507,7 @@ if(MSVC) endif() else() # On non-Windows, hide all symbols we don't need @@ -55,10 +55,10 @@ index d15d97ed..bdacac99 100644 set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)") set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden) set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1) -@@ -631,20 +647,9 @@ endif() - if(MSVC) - target_compile_options(onnx_proto +@@ -620,21 +636,11 @@ if(MSVC) PRIVATE /MP + /wd4146 # unary minus operator applied to unsigned type, + # result still unsigned - /wd4244 #'argument': conversion from 'google:: - #protobuf::uint64' to 'int', possible - # loss of data @@ -67,42 +67,93 @@ index d15d97ed..bdacac99 100644 ${EXTRA_FLAGS}) target_compile_options(onnx PRIVATE /MP + /wd4146 # unary minus operator applied to unsigned type, + # result still unsigned - /wd4244 # 'argument': conversion from 'google:: - # protobuf::uint64' to 'int', possible - # loss of data - /wd4267 # Conversion from 'size_t' to 'int', - # possible loss of data -- /wd4996 # The second parameter is ignored. ${EXTRA_FLAGS}) - if(ONNX_USE_PROTOBUF_SHARED_LIBS) - target_compile_options(onnx_proto -diff --git a/onnx/common/file_utils.h b/onnx/common/file_utils.h -index b847798e..a6c31904 100644 ---- a/onnx/common/file_utils.h -+++ b/onnx/common/file_utils.h -@@ -6,7 +6,6 @@ + add_msvc_runtime_flag(onnx_proto) + add_msvc_runtime_flag(onnx) +diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc +index 64366270..4aed9027 100644 +--- a/onnx/defs/nn/defs.cc ++++ b/onnx/defs/nn/defs.cc +@@ -36,7 +36,7 @@ static const char* conv_transpose_auto_pad_doc = + "on whether it is even or odd). In case the padding is an odd number, the extra " + "padding is added at the end for SAME_UPPER and at the beginning for SAME_LOWER."; + +-static void convPoolShapeInference( ++void convPoolShapeInference( + InferenceContext& ctx, + bool use_dilation, + bool require_kernel_shape, +@@ -1102,7 +1102,7 @@ ONNX_OPERATOR_SET_SCHEMA( + convPoolShapeInference(ctx, true, false, 0, 1); + })); + +-static void convTransposeShapeInference(InferenceContext& ctx) { ++void convTransposeShapeInference(InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); - #pragma once + // we need at least two inputs to have a shape for this inference. +@@ -1462,7 +1462,7 @@ ONNX_OPERATOR_SET_SCHEMA( + })); --#include - #include - #include + // For GlobalPool operations. +-static void globalPoolTypeShapeInference(InferenceContext& ctx) { ++void globalPoolTypeShapeInference(InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); -@@ -17,8 +16,7 @@ namespace ONNX_NAMESPACE { + // needs at least one input with shape. +diff --git a/onnx/defs/nn/old.cc b/onnx/defs/nn/old.cc +index d8ca9a46..1eda4c70 100644 +--- a/onnx/defs/nn/old.cc ++++ b/onnx/defs/nn/old.cc +@@ -4023,7 +4023,6 @@ ONNX_OPERATOR_SET_SCHEMA( + GroupNormalization, + 18, + OpSchema() +- .Deprecate() + .SetDoc(GroupNormalization_ver18_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, 1e-5f) + .Attr( +diff --git a/onnx/defs/rnn/defs.cc b/onnx/defs/rnn/defs.cc +index c0ed3a39..6c8e2909 100644 +--- a/onnx/defs/rnn/defs.cc ++++ b/onnx/defs/rnn/defs.cc +@@ -5,7 +5,7 @@ + #include "onnx/defs/schema.h" - template - void LoadProtoFromPath(const std::string proto_path, T& proto) { -- std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path); -- std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary); -+ std::fstream proto_stream(proto_path, std::ios::in | std::ios::binary); - if (!proto_stream.good()) { - fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. "); - } + namespace ONNX_NAMESPACE { +-static void RNNShapeInference(InferenceContext& ctx) { ++void RNNShapeInference(InferenceContext& ctx) { + TensorShapeProto::Dimension num_directions, seq_length, batch_size, hidden_size; + + auto direction = getAttribute(ctx, "direction", "forward"); +diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h +index 42318d82..a33cf342 100644 +--- a/onnx/defs/schema.h ++++ b/onnx/defs/schema.h +@@ -980,10 +980,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { + class OpSchemaRegisterOnce final { + public: + // Export to cpp custom register macro +- explicit OpSchemaRegisterOnce( +- OpSchema op_schema, +- int opset_version_to_load = 0, +- bool fail_duplicate_schema = true) { ++ OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); + } + static void diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h -index 0aab3e26..398ac2d6 100644 +index 0aab3e26..27f32195 100644 --- a/onnx/onnx_pb.h +++ b/onnx/onnx_pb.h -@@ -47,10 +47,28 @@ +@@ -47,10 +47,30 @@ #define ONNX_API ONNX_IMPORT #endif @@ -119,6 +170,7 @@ index 0aab3e26..398ac2d6 100644 +#endif // defined(__has_warning) + +#endif // defined(__GNUC__) ++ + #ifdef ONNX_ML #include "onnx/onnx-ml.pb.h" @@ -129,5 +181,6 @@ index 0aab3e26..398ac2d6 100644 +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif ++ + #endif // ! ONNX_ONNX_PB_H diff --git a/cmake/vcpkg-ports/onnx/fix-cmakelists.patch b/cmake/vcpkg-ports/onnx/fix-cmakelists.patch index 2f5e79c95aff5..2b5e8a0540a91 100644 --- a/cmake/vcpkg-ports/onnx/fix-cmakelists.patch +++ b/cmake/vcpkg-ports/onnx/fix-cmakelists.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index b666eec..66c234d 100644 +index 6fe5c96e..633debb6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -63,6 +63,16 @@ endif() +@@ -70,6 +70,16 @@ endif() include(GNUInstallDirs) @@ -16,16 +16,6 @@ index b666eec..66c234d 100644 + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnx +) + - set(ONNX_ROOT ${PROJECT_SOURCE_DIR}) + set(ONNX_ROOT ${onnx_SOURCE_DIR}) # Read ONNX version -@@ -104,7 +114,8 @@ endif() - # find_package Python has replaced PythonInterp and PythonLibs since cmake 3.12 - # Use the following command in the future; now this is only compatible with the latest pybind11 - # find_package(Python ${PY_VERSION} COMPONENTS Interpreter Development REQUIRED) --find_package(PythonInterp ${PY_VERSION} REQUIRED) -+find_package(Python3 ${PY_VERSION} COMPONENTS Interpreter REQUIRED) -+set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE}) - if(BUILD_ONNX_PYTHON) - find_package(PythonLibs ${PY_VERSION}) - endif() diff --git a/cmake/vcpkg-ports/onnx/fix-dependency-protobuf.patch b/cmake/vcpkg-ports/onnx/fix-dependency-protobuf.patch index c435922d0103d..ceb50feb1e4b9 100644 --- a/cmake/vcpkg-ports/onnx/fix-dependency-protobuf.patch +++ b/cmake/vcpkg-ports/onnx/fix-dependency-protobuf.patch @@ -1,17 +1,17 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index d81ac1d..9f97998 100644 +index 6fe5c96e..ae828752 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -149,6 +149,7 @@ if(ONNX_BUILD_TESTS) +@@ -141,6 +141,7 @@ if(ONNX_BUILD_TESTS) set(googletest_STATIC_LIBRARIES GTest::gtest) endif() +find_package(protobuf CONFIG REQUIRED) + if(NOT ONNX_BUILD_CUSTOM_PROTOBUF) if((ONNX_USE_LITE_PROTO AND TARGET protobuf::libprotobuf-lite) OR ((NOT ONNX_USE_LITE_PROTO) AND TARGET protobuf::libprotobuf)) # Sometimes we need to use protoc compiled for host architecture while linking - # libprotobuf against target architecture. See https://github.com/caffe2/caffe diff --git a/cmake/ONNXConfig.cmake.in b/cmake/ONNXConfig.cmake.in -index d588f8a..dbd4398 100644 +index d588f8ae..dbd43986 100644 --- a/cmake/ONNXConfig.cmake.in +++ b/cmake/ONNXConfig.cmake.in @@ -6,9 +6,8 @@ diff --git a/cmake/vcpkg-ports/onnx/portfile.cmake b/cmake/vcpkg-ports/onnx/portfile.cmake index 16c5715483025..0cd6bfa305843 100644 --- a/cmake/vcpkg-ports/onnx/portfile.cmake +++ b/cmake/vcpkg-ports/onnx/portfile.cmake @@ -3,8 +3,8 @@ vcpkg_check_linkage(ONLY_STATIC_LIBRARY) vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO onnx/onnx - REF "v${VERSION}" - SHA512 5a18e2b19ec9c18c8b115fb7e12ed98eddaa581c95f15c4dd420cd6c86e7caa04f9a393da589e76b89cf9b3544abd3749a8c77c2446782f37502eb74e9b1f661 + REF 7fc2b81a275223f5b02a522d9d2649837542a7be + SHA512 6911b4e532a7735ef40660dee904877850234a600b39d46a8dab91f6506c6547e3bd10af5d5f0f0abc0c6e7e6e1fc04c0ea307eb9f4aef5c614eaaa50403804d PATCHES fix-cmakelists.patch fix-dependency-protobuf.patch diff --git a/cmake/vcpkg-ports/onnx/vcpkg.json b/cmake/vcpkg-ports/onnx/vcpkg.json index 8c3cd291e80b1..f0a356ce3bb8b 100644 --- a/cmake/vcpkg-ports/onnx/vcpkg.json +++ b/cmake/vcpkg-ports/onnx/vcpkg.json @@ -1,6 +1,6 @@ { "name": "onnx", - "version-semver": "1.17.0", + "version-semver": "1.18.0", "port-version": 1, "description": "Open standard for machine learning interoperability", "homepage": "https://onnx.ai", diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml b/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml new file mode 100644 index 0000000000000..fa0e957418fab --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml @@ -0,0 +1,103 @@ + + + + + + $(MSBuildThisFileDirectory)../../build/native/include/;%(AdditionalIncludeDirectories) + + + $(MSBuildThisFileDirectory)../../build/native/include/;%(AdditionalIncludeDirectories) + + + + + + $(MSBuildThisFileDirectory)../../runtimes/win-arm64x/native/onnxruntime.lib;%(AdditionalDependencies) + + + + + + $(MSBuildThisFileDirectory)../../runtimes/win-arm/native/onnxruntime.lib;%(AdditionalDependencies) + + + + + + $(MSBuildThisFileDirectory)../../runtimes/win-arm64x/native/onnxruntime.lib;%(AdditionalDependencies) + + + + + + $(MSBuildThisFileDirectory)../../runtimes/win-x86/native/onnxruntime.lib;%(AdditionalDependencies) + + + + + x86 + arm64x + arm + $(Platform) + + + + $(MSBuildThisFileDirectory)..\..\runtimes\win-$(EnginePlatform)\native\onnxruntime.dll + + + + + + onnxruntime.dll + PreserveNewest + false + + + onnxruntime_providers_shared.dll + PreserveNewest + false + + + + + onnxruntime.dll + PreserveNewest + false + + + onnxruntime_providers_shared.dll + PreserveNewest + false + + + + + onnxruntime.dll + PreserveNewest + false + + + onnxruntime_providers_shared.dll + PreserveNewest + false + + + + + onnxruntime.dll + PreserveNewest + false + + + diff --git a/dockerfiles/Dockerfile.source b/dockerfiles/Dockerfile.source index 5822a805c674e..ea28e144ee95a 100644 --- a/dockerfiles/Dockerfile.source +++ b/dockerfiles/Dockerfile.source @@ -4,18 +4,16 @@ # -------------------------------------------------------------- # Dockerfile to run ONNXRuntime with source build for CPU -FROM mcr.microsoft.com/cbl-mariner/base/python:3 +FROM mcr.microsoft.com/azurelinux/base/python:3 MAINTAINER Changming Sun "chasun@microsoft.com" ADD . /code RUN tdnf install -y tar ca-certificates build-essential cmake curl python3-devel python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf -# The latest cmake version in Mariner2 is 3.21, but we need 3.26+ -RUN /code/dockerfiles/scripts/install_cmake.sh # Prepare onnxruntime repository & build onnxruntime RUN cd /code && /bin/bash ./build.sh --allow_running_as_root --skip_submodule_sync --config Release --build_wheel --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) -FROM mcr.microsoft.com/cbl-mariner/base/python:3 +FROM mcr.microsoft.com/azurelinux/base/python:3 COPY --from=0 /code/build/Linux/Release/dist /root COPY --from=0 /code/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install coloredlogs humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl diff --git a/dockerfiles/README.md b/dockerfiles/README.md index 9f83fc390eee7..4c69098103edd 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -35,7 +35,7 @@ The docker file supports both x86_64 and ARM64(aarch64). You may use docker's "--platform" parameter to explicitly specify which CPU architecture you want to build. For example: ```bash - docker build --platform linux/arm64/v8 -f Dockerfile.source + docker build --platform linux/arm64/v8 -f Dockerfile.source .. ``` However, we cannot build the code for 32-bit ARM in such a way since a 32-bit compiler/linker might not have enough memory to generate the binaries. diff --git a/dockerfiles/scripts/install_cmake.sh b/dockerfiles/scripts/install_cmake.sh index e89c323460ac4..6229339251ec4 100755 --- a/dockerfiles/scripts/install_cmake.sh +++ b/dockerfiles/scripts/install_cmake.sh @@ -5,7 +5,7 @@ cd /tmp/src echo "Installing cmake" CPU_ARCH=`uname -m` -CMAKE_VERSION='3.27.3' +CMAKE_VERSION='3.28.0' curl https://github.com/Kitware/CMake/releases/download/v$CMAKE_VERSION/cmake-$CMAKE_VERSION-linux-$CPU_ARCH.tar.gz -sSL --retry 5 -o /tmp/src/cmake.tar.gz tar -zxf /tmp/src/cmake.tar.gz --strip=1 -C /usr rm -f /tmp/src/cmake.tar.gz diff --git a/docs/How_To_Update_ONNX_Dev_Notes.md b/docs/How_To_Update_ONNX_Dev_Notes.md index 997812d7e7acf..895b552508cf6 100644 --- a/docs/How_To_Update_ONNX_Dev_Notes.md +++ b/docs/How_To_Update_ONNX_Dev_Notes.md @@ -4,7 +4,12 @@ This note is only for ONNX Runtime developers. If you need to update the ONNX submodule to a different version, follow the steps below. -1. Update the ONNX submodule +## Update ONNX installation + +Currently, ONNXRUNTIME supports two ways to install ONNX cpp dependencies, one is through cmake/deps.txt, and the other one is by vcpkg. And both of them are guarded by CI. It is recommeded to test vcpkg within Windows machines. + +### Update the ONNX submodule (commit would be more precise than branch) + ```sh cd cmake/external/onnx git remote update @@ -12,34 +17,79 @@ git reset --hard cd .. git add onnx ``` -(Change the to yours. If you are not sure, use 'origin/master'. Like 'git reset --hard origin/master') -1. Update [cgmanifests/generated/cgmanifest.json](/cgmanifests/generated/cgmanifest.json). -This file should be generated. See [cgmanifests/README](/cgmanifests/README.md) for instructions. +(Change the to yours. If you are not sure, use 'origin/main'. Like 'git reset --hard origin/main') + +### Update cmake/deps.txt + +1. Update [cmake/deps.txt](/cmake/deps.txt) with the correct zip download link and SHA (alternatively, build it with the wrong SHA and ORT should tell you the expected one.). +2. Check [cmake/patches/onnx/onnx.patch](/cmake/patches/onnx/onnx.patch) to see whether the diffs are resolved in the latest ONNX version. +3. Try to build ONNXRUNTIME from source. If the build fails, please make the changes accordingly, or use onnx.patch if it's ONNX bugs. An example build: + +```bash +./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.6/ --cudnn_home /usr/local/cuda-12.6/ --build_wheel --parallel --skip_tests +``` + +### Update cmake/vcpkg-ports + +1. Modify [cmake/vcpkg-ports/onnx/binskim.patch](/cmake/vcpkg-ports/onnx/binskim.patch) to be the same as [cmake/patches/onnx/onnx.patch](/cmake/patches/onnx/onnx.patch). +2. The other patches are required/created by vcpkg repository to build ONNX. We just need to re-run diff to makes sure the patches can be applied in the updated ONNX version. +3. Update [cmake/vcpkg-ports/onnx/portfile.cmake](/cmake/vcpkg-ports/onnx/portfile.cmake) with the correct commit id and SHA512. (alternatively, build it with the wrong SHA and ORT should tell you the expected one.) +4. Try to build ONNXRUNTIME from source. If the build fails, please make the changes accordingly, or use binskim.patch if it's ONNX bugs. An example build: + +```bash +./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.6/ --cudnn_home /usr/local/cuda-12.6/ --build_wheel --parallel --skip_tests --use_vcpkg +``` + +## Update ONNX related documentations + +We need to update the auto-generated ONNX kernels markdowns and requirements.txt. + +### AUTO-generated ONNX kernels + +We can either use the following command lines to generate them, or go to CI (AzureDevOps published Artifacts) to download and upload the generated markdowns from CI to update them (suggested). + +If you want to do the command lines: + +1. Update [docs/OperatorKernels.md](/docs/OperatorKernels.md) + +```bash +# under onnxruntime root +python tools/python/gen_opkernel_doc.py --output_path docs/OperatorKernels.md +``` + +1. Update [js/web/docs/webgl-operators.md](/js/web/docs/webgl-operators.md) with the script: [generate-webgl-operator-md.ts](/js/web/script/generate-webgl-operator-md.ts) + +```bash +node /home/titaiwang/onnxruntime/js/web/script/generate-webgl-operator-md.js +``` + +### Update requirements.txt + +Update Python requirements files with the updated ONNX version (e.g., `onnx==1.16.0`) or commit hash if building from source (e.g., `git+http://github.com/onnx/onnx.git@targetonnxcommithash#egg=onnx`). -1. Update Python requirements files with the updated ONNX version (e.g., `onnx==1.16.0`) or commit hash if building from source (e.g., `git+http://github.com/onnx/onnx.git@targetonnxcommithash#egg=onnx`). - [onnxruntime/test/python/requirements.txt](/onnxruntime/test/python/requirements.txt) - [tools/ci_build/github/linux/docker/scripts/requirements.txt](/tools/ci_build/github/linux/docker/scripts/requirements.txt) - [tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt](/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt) - [tools/ci_build/github/linux/python/requirements.txt](/tools/ci_build/github/linux/python/requirements.txt) - Run `git grep -rn "onnx==1" .` to find other locations and update this document if necessary. +## Additional Notes + 1. If there is any change to `cmake/external/onnx/onnx/*.in.proto`, you need to regenerate OnnxMl.cs. [Building onnxruntime with Nuget](https://onnxruntime.ai/docs/build/inferencing.html#build-nuget-packages) will do this. - -1. If you are updating ONNX from a released tag to a new commit, please ask Changming (@snnn) to deploy the new test +2. If you are updating ONNX from a released tag to a new commit, please ask Changming (@snnn) to deploy the new test data along with other test models to our CI build machines. This is to ensure that our tests cover every ONNX opset. - -1. Send your PR, and **manually** queue a build for every packaging pipeline for your branch. - -1. If there is a build failure in stage "Check out of dated documents" in WebAssembly CI pipeline, update ONNX Runtime +3. Send your PR, and **manually** queue a build for every packaging pipeline for your branch. +4. If there is a build failure in stage "Check out of dated documents" in WebAssembly CI pipeline, update ONNX Runtime Web WebGL operator support document: + - Make sure Node.js is installed (see [Prerequisites](../js/README.md#Prerequisites) for instructions). - Follow [js/Build](../js/README.md#Build-2) to install dependencies. - Follow instructions in [Generate document](../js/README.md#Generating-Document) to update document. Commit changes applied to file `docs/operators.md`. +5. Usually some newly introduced tests will fail. Then you may need to update -2. Usually some newly introduced tests will fail. Then you may need to update - [onnxruntime/test/onnx/main.cc](/onnxruntime/test/onnx/main.cc) - [onnxruntime/test/providers/cpu/model_tests.cc](/onnxruntime/test/providers/cpu/model_tests.cc) - [csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs](/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs) @@ -47,5 +97,6 @@ This file should be generated. See [cgmanifests/README](/cgmanifests/README.md) - [onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc](/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc) 1. If an operator has changed we may need to update optimizers involving that operator. + - Run [find_optimizer_opset_version_updates_required.py](/tools/python/find_optimizer_opset_version_updates_required.py), compare with the output from the current main branch, and check for any new warnings. - If there are new warnings contact the optimizer owner (which can usually be determined by looking at who edited the file most recently) or failing that ask the 'ONNX Runtime Shared Core' mailing list. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index a20333e2340c4..1c30e67534a0c 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -58,7 +58,8 @@ Do not modify directly.* |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BlackmanWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|21+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -76,7 +77,8 @@ Do not modify directly.* |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|21+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|23+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||20|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[9, 19]|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|22+|**T** = tensor(float)| @@ -98,7 +100,8 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(uint8)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(uint8)| |||[1, 10]|**T** = tensor(double), tensor(float)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T3**|23+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||[21, 22]|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||[19, 20]|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||[13, 18]|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||[10, 12]|**T** = tensor(int32), tensor(int8), tensor(uint8)| @@ -129,7 +132,8 @@ Do not modify directly.* |||[8, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |EyeLike|*in* input:**T1**
*out* output:**T2**|22+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)
**T2** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)| |||[9, 21]|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)
**T2** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint64)| -|Flatten|*in* input:**T**
*out* output:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Flatten|*in* input:**T**
*out* output:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -172,13 +176,15 @@ Do not modify directly.* |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float)| |||[11, 12]|**T** = tensor(float)| |||[1, 10]|**T** = tensor(float)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|21+|**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|23+|**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[19, 20]|**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[16, 18]|**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[14, 15]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|If|*in* cond:**B**
*out* outputs:**V**|21+|**B** = tensor(bool)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|If|*in* cond:**B**
*out* outputs:**V**|23+|**B** = tensor(bool)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**B** = tensor(bool)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[19, 20]|**B** = tensor(bool)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[16, 18]|**B** = tensor(bool)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 15]|**B** = tensor(bool)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -211,7 +217,8 @@ Do not modify directly.* |LogSoftmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| -|Loop|*in* M:**I**
*in* cond:**B**
*in* v_initial:**V**
*out* v_final_and_scan_outputs:**V**|21+|**B** = tensor(bool)
**I** = tensor(int64)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Loop|*in* M:**I**
*in* cond:**B**
*in* v_initial:**V**
*out* v_final_and_scan_outputs:**V**|23+|**B** = tensor(bool)
**I** = tensor(int64)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[19, 20]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[16, 18]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = optional(seq(tensor(bfloat16))), optional(seq(tensor(bool))), optional(seq(tensor(double))), optional(seq(tensor(float))), optional(seq(tensor(float16))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(int8))), optional(seq(tensor(string))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(uint8))), optional(tensor(bfloat16)), optional(tensor(bool)), optional(tensor(double)), optional(tensor(float)), optional(tensor(float16)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(int8)), optional(tensor(string)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(uint8)), seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 15]|**B** = tensor(bool)
**I** = tensor(int64)
**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -271,7 +278,8 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(float)| |||[9, 15]|**T** = tensor(float)| |||[7, 8]|**T** = tensor(float)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|23+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||[19, 20]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||18|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -285,7 +293,8 @@ Do not modify directly.* |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|21+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**TS** = tensor(float)| |||[10, 20]|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**T2**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|23+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| +|||[21, 22]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)| |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |||[10, 12]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| @@ -346,7 +355,8 @@ Do not modify directly.* |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8)| |||13|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|||[21, 22]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[19, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[14, 18]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| @@ -365,7 +375,8 @@ Do not modify directly.* |STFT|*in* signal:**T1**
*in* frame_step:**T2**
*in* window:**T1**
*in* frame_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |Scale|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| -|Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|21+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|23+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[19, 20]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[16, 18]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 15]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -388,7 +399,8 @@ Do not modify directly.* |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shape|*in* data:**T**
*out* shape:**T1**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||[21, 22]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[19, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[15, 18]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[13, 14]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| @@ -403,7 +415,8 @@ Do not modify directly.* |||[7, 21]|**T** = tensor(double), tensor(float)| |Sinh|*in* input:**T**
*out* output:**T**|22+|**T** = tensor(float)| |||[9, 21]|**T** = tensor(float)| -|Size|*in* data:**T**
*out* size:**T1**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|23+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||[21, 22]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[19, 20]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| @@ -427,7 +440,8 @@ Do not modify directly.* |SplitToSequence|*in* input:**T**
*in* split:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)| |Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| -|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -453,12 +467,14 @@ Do not modify directly.* |TopK|*in* X:**T**
*in* K:**tensor(int64)**
*out* Values:**T**
*out* Indices:**I**

or

*in* X:**T**
*out* Values:**T**
*out* Indices:**I**|11+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| |||[1, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| -|Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|Transpose|*in* data:**T**
*out* transposed:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||[21, 22]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Trilu|*in* input:**T**
*in* k:**tensor(int64)**
*out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int64)| |Unique|*in* X:**T**
*out* Y:**T**
*out* indices:**tensor(int64)**
*out* inverse_indices:**tensor(int64)**
*out* counts:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(int64), tensor(int8), tensor(string)| -|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**

or

*in* data:**T**
*out* expanded:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Unsqueeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* expanded:**T**

or

*in* data:**T**
*out* expanded:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -630,7 +646,7 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T3**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||[19, 20]|**T1** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||[13, 18]|**T** = tensor(int8), tensor(uint8)| |||[10, 12]|**T** = tensor(int8), tensor(uint8)| @@ -762,7 +778,7 @@ Do not modify directly.* |||[13, 14]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| |||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| |||[7, 11]|**T** = tensor(double), tensor(float), tensor(float16)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**T2**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)| |||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e5m2), tensor(int8), tensor(uint8)| |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |||[10, 12]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| @@ -1069,7 +1085,7 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|21+|**T1** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T3**|21+|**T1** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)| @@ -1228,7 +1244,7 @@ Do not modify directly.* |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|21+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| |||10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**T2**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int4), tensor(int8), tensor(uint4), tensor(uint8)| |||19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||10+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 523d2a9d1a8be..ce7d4aaf652d0 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -41,6 +41,7 @@ struct OrtArenaCfg { namespace onnxruntime { constexpr const char* CPU = "Cpu"; +constexpr const char* CPU_ALIGNED_4K = "CpuAligned4K"; constexpr const char* CUDA = "Cuda"; constexpr const char* CUDA_PINNED = "CudaPinned"; constexpr const char* CANN = "Cann"; @@ -57,6 +58,7 @@ constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer"; constexpr const char* WEBNN_TENSOR = "WebNN_Tensor"; constexpr size_t kAllocAlignment = 256; +constexpr const size_t kAlloc4KAlignment = 4096; class IAllocator; class Stream; @@ -270,4 +272,7 @@ using AllocatorMap = std::map; void* AllocatorDefaultAlloc(size_t size); void AllocatorDefaultFree(void* p); +void* AllocatorDefaultAllocAligned(size_t size, size_t alignment); +void AllocatorDefaultFreeAligned(void* p, size_t alignment); + } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index adade482f6a17..472575d1998f5 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -11,6 +11,7 @@ struct OrtDevice { using DeviceType = int8_t; using MemoryType = int8_t; using DeviceId = int16_t; + using Alignment = size_t; // Pre-defined device types. static const DeviceType CPU = 0; @@ -28,31 +29,40 @@ struct OrtDevice { static const MemoryType QNN_HTP_SHARED = 4; }; - constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_) + constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_, Alignment alignment) noexcept : device_type(device_type_), memory_type(memory_type_), - device_id(device_id_) {} + device_id(device_id_), + alignment(alignment) {} - constexpr OrtDevice() : OrtDevice(CPU, MemType::DEFAULT, 0) {} + constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_) noexcept + : OrtDevice(device_type_, memory_type_, device_id_, 0) {} - DeviceType Type() const { + constexpr OrtDevice() noexcept : OrtDevice(CPU, MemType::DEFAULT, 0) {} + + DeviceType Type() const noexcept { return device_type; } - MemoryType MemType() const { + MemoryType MemType() const noexcept { return memory_type; } - DeviceId Id() const { + DeviceId Id() const noexcept { return device_id; } + Alignment GetAlignment() const noexcept { + return alignment; + } + std::string ToString() const { std::ostringstream ostr; ostr << "Device:[" << "DeviceType:" << static_cast(device_type) << " MemoryType:" << static_cast(memory_type) << " DeviceId:" << device_id + << " Alignment:" << alignment << "]"; return ostr.str(); } @@ -62,6 +72,7 @@ struct OrtDevice { auto h = std::hash()(device_type); onnxruntime::HashCombine(memory_type, h); onnxruntime::HashCombine(device_id, h); + onnxruntime::HashCombine(alignment, h); return h; } @@ -71,8 +82,10 @@ struct OrtDevice { return device_type < other.device_type; if (memory_type != other.memory_type) return memory_type < other.memory_type; + if (device_id != other.device_id) + return device_id < other.device_id; - return device_id < other.device_id; + return alignment < other.alignment; } private: @@ -84,6 +97,9 @@ struct OrtDevice { // Device index. int32_t device_id : 16; + + // Required alignment + Alignment alignment; }; inline bool operator==(const OrtDevice& left, const OrtDevice& other) { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5288296fd4750..9a5891f9e236d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -700,6 +700,9 @@ typedef struct OrtModelEditorApi OrtModelEditorApi; struct OrtCompileApi; typedef struct OrtCompileApi OrtCompileApi; +struct OrtEpApi; +typedef struct OrtEpApi OrtEpApi; + /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -5186,6 +5189,12 @@ struct OrtApi { * \since Version 1.22. */ const OrtHardwareDevice*(ORT_API_CALL* EpDevice_Device)(_In_ const OrtEpDevice* ep_device); + + /** \brief Get the OrtEpApi instance for implementing an execution provider. + * + * \since Version 1.22. + */ + const OrtEpApi*(ORT_API_CALL* GetEpApi)(); }; /* @@ -5889,6 +5898,29 @@ struct OrtCompileApi { ORT_RUNTIME_CLASS(Ep); ORT_RUNTIME_CLASS(EpFactory); +struct OrtEpApi { + /** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice. + * \param[in] ep_factory Execution provider factory that is creating the instance. + * \param[in] hardware_device Hardware device that the EP can utilize. + * \param[in] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used + * during execution provider selection and passed to CreateEp. + * ep_device will copy this instance and the user should call ReleaseKeyValuePairs. + * \param[in] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added + * to the Session configuration options if the execution provider is selected. + * ep_device will copy this instance and the user should call ReleaseKeyValuePairs. + * \param ep_device OrtExecutionDevice that is created. + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateEpDevice, _In_ OrtEpFactory* ep_factory, + _In_ const OrtHardwareDevice* hardware_device, + _In_opt_ const OrtKeyValuePairs* ep_metadata, + _In_opt_ const OrtKeyValuePairs* ep_options, + _Out_ OrtEpDevice** ep_device); + + ORT_CLASS_RELEASE(EpDevice); +}; + /** * \brief The OrtEp struct provides functions to implement for an execution provider. * \since Version 1.22. @@ -5993,21 +6025,28 @@ struct OrtEpFactory { /** \brief Get information from the execution provider if it supports the OrtHardwareDevice. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] device The OrtHardwareDevice instance. - * \param[out] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used - * during execution provider selection and/or CreateEp. - * \param[out] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added - * to the Session configuration options if the execution provider is selected. + * Non-const as the factory is passed through to the CreateEp call via the OrtEpDevice. + * \param[in] devices The OrtHardwareDevice instances that are available. + * \param[in] num_devices The number of OrtHardwareDevice instances. + * \param[out] ep_devices OrtEpDevice instances for each OrtHardwareDevice that the EP can use. + * The implementation should call OrtEpApi::CreateEpDevice to create, and add the OrtEpDevice + * instances to this pre-allocated array. ORT will take ownership of the values returned. + * i.e. usage is `ep_devices[0] = ;` + * \param[in] max_ep_devices The maximum number of OrtEpDevices that can be added to ep_devices. + * Current default is 8. This can be increased if needed. + * \param[out] num_ep_devices The number of EP devices added to ep_devices. * \return true if the factory can create an execution provider that uses `device`. * * \note ORT will take ownership or ep_metadata and/or ep_options if they are not null. * * \since Version 1.22. */ - bool(ORT_API_CALL* GetDeviceInfoIfSupported)(const OrtEpFactory* this_ptr, - _In_ const OrtHardwareDevice* device, - _Out_opt_ OrtKeyValuePairs** ep_metadata, - _Out_opt_ OrtKeyValuePairs** ep_options); + OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); /** \brief Function to create an OrtEp instance for use in a Session. * @@ -6015,11 +6054,11 @@ struct OrtEpFactory { * * \param[in] this_ptr The OrtEpFactory instance. * \param[in] devices The OrtHardwareDevice instances that the execution provider was selected to use. - * \param[in] ep_metadata_pairs Execution provider metadata that was returned in GetDeviceInfoIfSupported, for each + * \param[in] ep_metadata_pairs Execution provider metadata that was provided to OrtEpApi::CreateEpDevice, for each * device. * \param[in] num_devices The number of devices the execution provider was selected for. * \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the - * session. This will include ep_options from GetDeviceInfoIfSupported as well as any + * session. This will include ep_options from GetSupportedDevices as well as any * user provided overrides. * Execution provider options will have been added with a prefix of 'ep..'. * The OrtSessionOptions instance will NOT be valid after this call and should not be @@ -6029,7 +6068,7 @@ struct OrtEpFactory { * * \snippet{doc} snippets.dox OrtStatus Return Value * - * \since Version 1.22. + * \since Version . This is a placeholder. */ OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, _In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -6043,7 +6082,7 @@ struct OrtEpFactory { * \param[in] this_ptr The OrtEpFactory instance. * \param[in] ep The OrtEp instance to release. * - * \since Version 1.22. + * \since Version . This is a placeholder. */ void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a2937b6e82a27..0ecc27c59dc28 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -172,6 +172,20 @@ inline const OrtCompileApi& GetCompileApi() { return *api; } +/// +/// This returns a reference to the ORT C EP API. Used if authoring a plugin execution provider. +/// +/// ORT C EP API reference +inline const OrtEpApi& GetEpApi() { + auto* api = GetApi().GetEpApi(); + if (api == nullptr) { + // minimal build + ORT_CXX_API_THROW("EP API is not available in this build", ORT_FAIL); + } + + return *api; +} + /** \brief IEEE 754 half-precision floating point data type * * \details This struct is used for converting float to float16 and back @@ -559,7 +573,9 @@ ORT_DEFINE_RELEASE(ValueInfo); ORT_DEFINE_RELEASE(Node); ORT_DEFINE_RELEASE(Graph); ORT_DEFINE_RELEASE(Model); +ORT_DEFINE_RELEASE(KeyValuePairs) ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); +ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); #undef ORT_DEFINE_RELEASE #undef ORT_DEFINE_RELEASE_FROM_API_STRUCT @@ -675,6 +691,7 @@ struct AllocatedFree { struct AllocatorWithDefaultOptions; struct Env; +struct EpDevice; struct Graph; struct Model; struct Node; @@ -737,6 +754,94 @@ struct ThreadingOptions : detail::Base { ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); }; +namespace detail { +template +struct KeyValuePairsImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + const char* GetValue(const char* key) const; + + // get the pairs in unordered_map. needs to copy to std::string so the hash works as expected + std::unordered_map GetKeyValuePairs() const; + // get the pairs in two vectors. entries will be 1:1 between keys and values. avoids copying to std::string + void GetKeyValuePairs(std::vector& keys, std::vector& values) const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstKeyValuePairs = detail::KeyValuePairsImpl>; + +/** \brief Wrapper around ::OrtKeyValuePair */ +struct KeyValuePairs : detail::KeyValuePairsImpl { + explicit KeyValuePairs(std::nullptr_t) {} ///< No instance is created + /// Take ownership of a pointer created by C API + explicit KeyValuePairs(OrtKeyValuePairs* p) : KeyValuePairsImpl{p} {} + + /// \brief Wraps OrtApi::CreateKeyValuePairs + explicit KeyValuePairs(); + + /// \brief Wraps OrtApi::CreateKeyValuePairs and OrtApi::AddKeyValuePair + explicit KeyValuePairs(const std::unordered_map& kv_pairs); + + /// \brief Wraps OrtApi::AddKeyValuePair + void Add(const char* key, const char* value); + + /// \brief Wraps OrtApi::RemoveKeyValuePair + void Remove(const char* key); + + ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; } +}; + +namespace detail { +template +struct HardwareDeviceImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + OrtHardwareDeviceType Type() const; + uint32_t VendorId() const; + uint32_t DeviceId() const; + const char* Vendor() const; + ConstKeyValuePairs Metadata() const; +}; +} // namespace detail + +/** \brief Wrapper around ::OrtHardwareDevice + * \remarks HardwareDevice is always read-only for API users. + */ +using ConstHardwareDevice = detail::HardwareDeviceImpl>; + +namespace detail { +template +struct EpDeviceImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + const char* EpName() const; + const char* EpVendor() const; + ConstKeyValuePairs EpMetadata() const; + ConstKeyValuePairs EpOptions() const; + ConstHardwareDevice Device() const; +}; +} // namespace detail + +/** \brief Wrapper around ::OrtEpDevice + * \remarks EpDevice is always read-only for ORT API users. + */ +using ConstEpDevice = detail::EpDeviceImpl>; + +/** \brief Mutable EpDevice that is created by EpApi users. + */ +struct EpDevice : detail::EpDeviceImpl { + explicit EpDevice(std::nullptr_t) {} ///< No instance is created + explicit EpDevice(OrtEpDevice* p) : EpDeviceImpl{p} {} ///< Take ownership of a pointer created by C API + + /// \brief Wraps OrtEpApi::CreateEpDevice + EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device, + ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {}); +}; + /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. @@ -768,7 +873,14 @@ struct Env : detail::Base { Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator - Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 + Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, + const std::unordered_map& options, + const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 + + Env& RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string& path); ///< Wraps OrtApi::RegisterExecutionProviderLibrary + Env& UnregisterExecutionProviderLibrary(const char* registration_name); ///< Wraps OrtApi::UnregisterExecutionProviderLibrary + + std::vector GetEpDevices() const; }; /** \brief Custom Op Domain @@ -919,7 +1031,7 @@ struct ConstSessionOptionsImpl : Base { std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry - std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); + std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def) const; }; template @@ -981,6 +1093,11 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); + SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector& ep_devices, + const KeyValuePairs& ep_options); + SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector& ep_devices, + const std::unordered_map& ep_options); + SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index e41ef005349ac..48b3b80cced55 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -479,6 +479,125 @@ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustom return *this; } +namespace detail { +template +inline const char* KeyValuePairsImpl::GetValue(const char* key) const { + return GetApi().GetKeyValue(this->p_, key); +} + +template +inline std::unordered_map KeyValuePairsImpl::GetKeyValuePairs() const { + std::unordered_map out; + + size_t num_pairs = 0; + const char* const* keys = nullptr; + const char* const* values = nullptr; + GetApi().GetKeyValuePairs(this->p_, &keys, &values, &num_pairs); + if (num_pairs > 0) { + out.reserve(num_pairs); + for (size_t i = 0; i < num_pairs; ++i) { + out.emplace(keys[i], values[i]); + } + } + + return out; +} + +template +inline void KeyValuePairsImpl::GetKeyValuePairs(std::vector& keys, + std::vector& values) const { + keys.clear(); + values.clear(); + + size_t num_pairs = 0; + const char* const* keys_ptr = nullptr; + const char* const* values_ptr = nullptr; + GetApi().GetKeyValuePairs(this->p_, &keys_ptr, &values_ptr, &num_pairs); + if (num_pairs > 0) { + keys.resize(num_pairs); + values.resize(num_pairs); + std::copy(keys_ptr, keys_ptr + num_pairs, keys.begin()); + std::copy(values_ptr, values_ptr + num_pairs, values.begin()); + } +} +} // namespace detail + +inline KeyValuePairs::KeyValuePairs() { + GetApi().CreateKeyValuePairs(&p_); +} + +inline KeyValuePairs::KeyValuePairs(const std::unordered_map& kv_pairs) { + GetApi().CreateKeyValuePairs(&p_); + for (const auto& kv : kv_pairs) { + GetApi().AddKeyValuePair(this->p_, kv.first.c_str(), kv.second.c_str()); + } +} + +inline void KeyValuePairs::Add(const char* key, const char* value) { + GetApi().AddKeyValuePair(this->p_, key, value); +} + +inline void KeyValuePairs::Remove(const char* key) { + GetApi().RemoveKeyValuePair(this->p_, key); +} + +namespace detail { +template +inline OrtHardwareDeviceType HardwareDeviceImpl::Type() const { + return GetApi().HardwareDevice_Type(this->p_); +} + +template +inline uint32_t HardwareDeviceImpl::VendorId() const { + return GetApi().HardwareDevice_VendorId(this->p_); +} + +template +inline uint32_t HardwareDeviceImpl::DeviceId() const { + return GetApi().HardwareDevice_DeviceId(this->p_); +} + +template +inline const char* HardwareDeviceImpl::Vendor() const { + return GetApi().HardwareDevice_Vendor(this->p_); +} + +template +inline ConstKeyValuePairs HardwareDeviceImpl::Metadata() const { + return ConstKeyValuePairs{GetApi().HardwareDevice_Metadata(this->p_)}; +} + +template +inline const char* EpDeviceImpl::EpName() const { + return GetApi().EpDevice_EpName(this->p_); +} + +template +inline const char* EpDeviceImpl::EpVendor() const { + return GetApi().EpDevice_EpVendor(this->p_); +} + +template +inline ConstKeyValuePairs EpDeviceImpl::EpMetadata() const { + return ConstKeyValuePairs(GetApi().EpDevice_EpMetadata(this->p_)); +} + +template +inline ConstKeyValuePairs EpDeviceImpl::EpOptions() const { + return ConstKeyValuePairs(GetApi().EpDevice_EpOptions(this->p_)); +} + +template +inline ConstHardwareDevice EpDeviceImpl::Device() const { + return ConstHardwareDevice(GetApi().EpDevice_Device(this->p_)); +} +} // namespace detail + +inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device, + ConstKeyValuePairs ep_metadata, ConstKeyValuePairs ep_options) { + ThrowOnError(GetEpApi().CreateEpDevice(&ep_factory, hardware_device, ep_metadata, ep_options, &p_)); +} + inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) { ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_)); if (strcmp(logid, "onnxruntime-node") == 0) { @@ -551,6 +670,33 @@ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, return *this; } +inline Env& Env::RegisterExecutionProviderLibrary(const char* registration_name, + const std::basic_string& path) { + ThrowOnError(GetApi().RegisterExecutionProviderLibrary(p_, registration_name, path.c_str())); + return *this; +} + +inline Env& Env::UnregisterExecutionProviderLibrary(const char* registration_name) { + ThrowOnError(GetApi().UnregisterExecutionProviderLibrary(p_, registration_name)); + return *this; +} + +inline std::vector Env::GetEpDevices() const { + size_t num_devices = 0; + const OrtEpDevice* const* device_ptrs = nullptr; + ThrowOnError(GetApi().GetEpDevices(p_, &device_ptrs, &num_devices)); + + std::vector devices; + if (num_devices > 0) { + devices.reserve(num_devices); + for (size_t i = 0; i < num_devices; ++i) { + devices.emplace_back(device_ptrs[i]); + } + } + + return devices; +} + inline CustomOpDomain::CustomOpDomain(const char* domain) { ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); } @@ -717,7 +863,8 @@ inline bool ConstSessionOptionsImpl::HasConfigEntry(const char* config_key) c } template -inline std::string ConstSessionOptionsImpl::GetConfigEntryOrDefault(const char* config_key, const std::string& def) { +inline std::string ConstSessionOptionsImpl::GetConfigEntryOrDefault(const char* config_key, + const std::string& def) const { if (!this->HasConfigEntry(config_key)) { return def; } @@ -955,6 +1102,53 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider( return *this; } +namespace { +template +void SessionOptionsAppendEP(detail::SessionOptionsImpl& session_options, + Env& env, const std::vector& ep_devices, + const std::vector& ep_options_keys, + const std::vector& ep_options_values) { + std::vector ep_devices_ptrs; + ep_devices_ptrs.reserve(ep_devices.size()); + for (const auto& ep_device : ep_devices) { + ep_devices_ptrs.push_back(ep_device); + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_V2( + session_options, env, ep_devices_ptrs.data(), ep_devices_ptrs.size(), + ep_options_keys.data(), ep_options_values.data(), ep_options_keys.size())); +} +} // namespace + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_V2( + Env& env, const std::vector& ep_devices, const KeyValuePairs& ep_options) { + std::vector ep_options_keys, ep_options_values; + ep_options.GetKeyValuePairs(ep_options_keys, ep_options_values); + + SessionOptionsAppendEP(*this, env, ep_devices, ep_options_keys, ep_options_values); + + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_V2( + Env& env, const std::vector& ep_devices, + const std::unordered_map& ep_options) { + std::vector ep_options_keys, ep_options_values; + ep_options_keys.reserve(ep_options.size()); + ep_options_values.reserve(ep_options.size()); + + for (const auto& [key, value] : ep_options) { + ep_options_keys.push_back(key.c_str()); + ep_options_values.push_back(value.c_str()); + } + + SessionOptionsAppendEP(*this, env, ep_devices, ep_options_keys, ep_options_values); + + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn)); diff --git a/js/web/docs/webgl-operators.md b/js/web/docs/webgl-operators.md index 3aec0aa3d7cf3..0cf282e9c9b12 100644 --- a/js/web/docs/webgl-operators.md +++ b/js/web/docs/webgl-operators.md @@ -4,7 +4,7 @@ The following table shows [ai.onnx](https://github.com/onnx/onnx/blob/main/docs/ See [Compatibility](../README.md#Compatibility) for a list of the supported platforms. -*This file is automatically generated from the def files via [this script](../script/generate-operator-md.ts). Do not modify directly.* +*This file is automatically generated from the def files via [generate-webgl-operator-md.ts](../script/generate-webgl-operator-md.ts). Do not modify directly.* | Operator | WebGl Backend | |:--------:|:-------------:| @@ -20,6 +20,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Asinh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Asinh) | | | [Atan](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Atan) | [7-21](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Atan-7), [22+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Atan-22) | | [Atanh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Atanh) | | +| [Attention](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Attention) | | | [AveragePool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool) | [7-9](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-7), [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-10), [11-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-11), [19-21](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-19), [22+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-22) | | [BatchNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BatchNormalization) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-7), [9-13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-9), [14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-14), [15+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-15) | | [Bernoulli](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Bernoulli) | | @@ -29,7 +30,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [BitwiseOr](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitwiseOr) | | | [BitwiseXor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitwiseXor) | | | [BlackmanWindow](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BlackmanWindow) | | -| [Cast](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast) | [6-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-6), [9-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-9), [13-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-13), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-21) | +| [Cast](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast) | [6-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-6), [9-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-9), [13-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-13), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-19), [21-22](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-21), [23+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-23) | | [CastLike](https://github.com/onnx/onnx/blob/main/docs/Operators.md#CastLike) | | | [Ceil](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Ceil) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Ceil-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Ceil-13) | | [Celu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Celu) | | @@ -62,7 +63,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Exp](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Exp) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-13) | | [Expand](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand) | | | [EyeLike](https://github.com/onnx/onnx/blob/main/docs/Operators.md#EyeLike) | | -| [Flatten](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Flatten) | [1-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-1), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-21) | +| [Flatten](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Flatten) | [1-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-1), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-13), [21-22](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-21), [23+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-23) | | [Floor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Floor) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Floor-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Floor-13) | | [GRU](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GRU) | | | [Gather](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13) | @@ -82,7 +83,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [HardSigmoid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSigmoid) | | | [HardSwish](https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSwish) | | | [Hardmax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax) | | -| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-21) | +| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19), [21-22](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-21), [23+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-23) | | [If](https://github.com/onnx/onnx/blob/main/docs/Operators.md#If) | | | [ImageDecoder](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ImageDecoder) | | | [InstanceNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization) | [6-21](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-6), [22+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-22) | @@ -124,11 +125,12 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [OptionalHasElement](https://github.com/onnx/onnx/blob/main/docs/Operators.md#OptionalHasElement) | | | [Or](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Or) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Or-7) | | [PRelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#PRelu) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-7), [9-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-9), [16+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-16) | -| [Pad](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-18), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-21) | +| [Pad](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-18), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-19), [21-22](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-21), [23+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-23) | | [Pow](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pow) | [7-11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-7), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-12), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-13), [15+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-15) | | [QLinearConv](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QLinearConv) | | | [QLinearMatMul](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QLinearMatMul) | | | [QuantizeLinear](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QuantizeLinear) | | +| [RMSNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RMSNormalization) | | | [RNN](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RNN) | | | [RandomNormal](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RandomNormal) | | | [RandomNormalLike](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RandomNormalLike) | | @@ -148,10 +150,11 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [ReduceSumSquare](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSumSquare) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-18) | | [RegexFullMatch](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RegexFullMatch) | | | [Relu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-6), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14) | -| [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-21) | +| [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19), [21-22](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-21), [23+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-23) | | [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-19) | | [ReverseSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReverseSequence) | | | [RoiAlign](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RoiAlign) | | +| [RotaryEmbedding](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RotaryEmbedding) | | | [Round](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Round) | | | [STFT](https://github.com/onnx/onnx/blob/main/docs/Operators.md#STFT) | | | [Scan](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Scan) | | @@ -166,7 +169,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [SequenceInsert](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceInsert) | | | [SequenceLength](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceLength) | | | [SequenceMap](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceMap) | | -| [Shape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-1), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-13), [15-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-21) | +| [Shape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-1), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-13), [15-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-19), [21-22](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-21), [23+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-23) | | [Shrink](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shrink) | | | [Sigmoid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sigmoid) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-13) | | [Sign](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sign) | | @@ -182,7 +185,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Split](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Split-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Split-11) | | [SplitToSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SplitToSequence) | | | [Sqrt](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-13) | -| [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-21) | +| [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13), [21-22](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-21), [23+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-23) | | [StringConcat](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringConcat) | | | [StringNormalizer](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringNormalizer) | | | [StringSplit](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringSplit) | | @@ -194,10 +197,10 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [ThresholdedRelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ThresholdedRelu) | | | [Tile](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Tile) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tile-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tile-13) | | [TopK](https://github.com/onnx/onnx/blob/main/docs/Operators.md#TopK) | | -| [Transpose](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Transpose) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-1), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-21) | +| [Transpose](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Transpose) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-1), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-13), [21-22](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-21), [23+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-23) | | [Trilu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Trilu) | | | [Unique](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unique) | | -| [Unsqueeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unsqueeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-21) | +| [Unsqueeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unsqueeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-13), [21-22](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-21), [23+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-23) | | [Upsample](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Upsample) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Upsample-7), [9](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Upsample-9) | | [Where](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where) | | | [Xor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Xor-7) | diff --git a/js/web/script/generate-webgl-operator-md.ts b/js/web/script/generate-webgl-operator-md.ts index 5cc43eb903527..2bba6cc271068 100644 --- a/js/web/script/generate-webgl-operator-md.ts +++ b/js/web/script/generate-webgl-operator-md.ts @@ -79,7 +79,7 @@ doc.write(`The following table shows [ai.onnx](https://github.com/onnx/onnx/blob ONNX Runtime Web currently support opset version 4 to 6, 8 and above.${EOL}${EOL}`); doc.write(`See [Compatibility](../README.md#Compatibility) for a list of the supported platforms.${EOL}${EOL}`); doc.write(`*This file is automatically generated from the\ - def files via [this script](../script/generate-operator-md.ts).\ + def files via [generate-webgl-operator-md.ts](../script/generate-webgl-operator-md.ts).\ Do not modify directly.*${EOL}${EOL}`); doc.write(`| Operator | WebGl Backend |${EOL}`); doc.write(`|:--------:|:-------------:|${EOL}`); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index d5a6a1ae699d9..8f013a1426ef8 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -111,8 +111,8 @@ class MatMulNBits final : public OpKernel { has_unquantized_zero_point_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8; } - ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(nbits_ == 4 || nbits_ == 8, + "Only 4b and 8b quantization is supported for MatMulNBits op, additional bits support is planned."); const Tensor* tensor_zero_point = nullptr; has_zp_input_ = info.TryGetConstantInput(InputIndex::zero_points, &tensor_zero_point); } @@ -439,6 +439,8 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, AllocatorPtr& allocator, concurrency::ThreadPool* thread_pool, const MatMulComputeHelper& helper) const { + ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for unpacked compute."); + const auto* a_data = a->Data(); const uint8_t* b_data = b->Data(); const auto* scales_data = scales->Data(); @@ -547,6 +549,7 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, AllocatorPtr& allocator, concurrency::ThreadPool* thread_pool, const MatMulComputeHelper& helper) const { + ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for unpacked compute."); const auto* a_data = a->Data(); const uint8_t* b_data = b->Data(); const auto* scales_data = scales->Data(); diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h index 9ebc44f4411eb..64bd2b7b1855e 100644 --- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h +++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include "core/framework/ort_value.h" #include "core/framework/float16.h" #include "contrib_ops/cpu/utils/debug_macros.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu index f334b72e70a34..aabbe4cc7582a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu @@ -100,9 +100,7 @@ bool is_supported(const cudaDeviceProp& dprops, int sequence_length_q, int sequence_length_kv, bool is_causal) { - bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; - bool is_sm90 = dprops.major == 9 && dprops.minor == 0; - return (is_sm8x || is_sm90) && + return (dprops.major >= 8) && (head_size_qk % 8 == 0) && (head_size_qk <= 256) && (head_size_v % 8 == 0) && (head_size_v <= 256) && (num_heads_q % num_heads_kv == 0) && diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 6a3e52bee3995..453dffaa2e6e6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -411,9 +411,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, } bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k) { - bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; - bool is_sm90 = dprops.major == 9 && dprops.minor == 0; - return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0); + return (dprops.major >= 8) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0); } // This API is used when past key and value are present... since cached, these are assumed to have sequence length diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index a6ea9f4b61271..3e0b9d35b1950 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -1041,7 +1041,7 @@ void initialize_moe_routing_kernelLauncher(const T *unpermuted_input, T *permute #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 template __global__ void finalize_moe_routing_kernel(const T *, T *, const T *, const T *, const T *, const T *, const int *, - const int *, int, const int) { + const int *, int, int) { // Does not support pre-Kepler architectures ; } @@ -1168,4 +1168,4 @@ template void finalize_moe_routing_kernelLauncher(const float *, float *, const template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *, const half *, const int *, const int *, int, int, int, cudaStream_t); -} // namespace ort_fastertransformer \ No newline at end of file +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index 580b5087f3fa3..f6f380c8211f6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -7,6 +7,43 @@ namespace onnxruntime { namespace contrib { namespace cuda { + +/////////////////////////////////////////////////////////////////////////////// +// A more general block-wise dequantization implementation that supports +// different block sizes and block orientations (row-wise/column-wise). +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + +/** + * @brief Blockwise quantization constants + * @tparam ElementT source data type, e.g. fp32/fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +struct BlkQuantTraits { + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + + using ThreadBlk = Shape2D; +}; + template Status Dequantize4Bits( T* output, @@ -19,6 +56,18 @@ Status Dequantize4Bits( int block_size, cudaStream_t stream); +template +Status Dequantize8Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const ZeroT* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + /** * @brief Dequantize a block-wise quantized matrix, and store the result in a * column major matrix for use in subsequent GEMM. This implementation supports @@ -45,6 +94,17 @@ Status DequantizeBlockwise4b( int columns, cudaStream_t stream); +template +Status DequantizeBlockwise8b( + T* dst, + const uint8_t* qelements, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu similarity index 70% rename from onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu rename to onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu index 7fb0619a799dc..cea1834fa1b62 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu @@ -1,4 +1,3 @@ -// Modifications: scaling is moved from masked softmax to the gemm before that. // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. @@ -20,7 +19,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { - __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * zp; @@ -68,25 +66,28 @@ __global__ void Dequantize4BitsKernelReOrder( int groups_per_K, int groups_per_threadblock, int total_groups) { - int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + constexpr int bits = 4; + constexpr int element_per_thread = 32 / bits; // Process 8 elements per thread using uint32_t load + constexpr int element_per_byte = 8 / bits; + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * element_per_thread) / block_size); if (group_id >= total_groups) { return; } - const int zero_point_shape_x = (groups_per_K + 1) / 2; + const int zero_point_shape_x = (groups_per_K + (element_per_byte - 1)) / element_per_byte; const int scales_shape_x = groups_per_K; int n_idx = group_id / scales_shape_x; int kb_idx = group_id % scales_shape_x; - int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); + int element_offset = group_id * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); T* output_i = output + element_offset; - uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); - const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * 8) & (block_size - 1)); - for (int i = 0; i < 8; i++) { + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / element_per_byte)); + const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); + for (int i = 0; i < element_per_thread; i++) { int32_t rid = reorder_idx_with_off[i]; T scale = *(scale_data + n_idx * scales_shape_x + rid); - uint8_t zp = 8; + uint8_t zp = 8; // Default zero point is 1 << (bits - 1) if (zero_points) { - zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = zero_points[n_idx * zero_point_shape_x + rid / element_per_byte]; zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); } @@ -130,7 +131,7 @@ __global__ void Dequantize4BitsKernel( } zero_point_value = static_cast(zp); } else { - zero_point_value = zero_points? *(zero_points + block_id):static_cast(8); + zero_point_value = zero_points ? *(zero_points + block_id) : static_cast(8); } output = output + element_offset; @@ -151,35 +152,45 @@ Status Dequantize4Bits( // k is padded and equal to block_per_K * block_size ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); constexpr int element_per_thread = 8; - int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_K = k / block_size; int total_groups = n * groups_per_K; // total elemenets in quant_data - int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock)); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_grid = CeilDiv(total_groups, groups_per_threadblock); + dim3 grid_dim(groups_per_grid); + dim3 block_dim(GridDim::maxThreadsPerBlock); + if (!reorder_idx || std::is_same_v) { - Dequantize4BitsKernel<<>>( + // Launch standard kernel + Dequantize4BitsKernel<<>>( output, quant_data, scales_data, zero_points, block_size, - groups_per_K, + groups_per_K, // Pass groups_per_K for potential ZP indexing if needed groups_per_threadblock, total_groups); } else { - // static_assert(std::is_same_v, "ZeroT must be uint8_t"); - Dequantize4BitsKernelReOrder<<>>( - output, - quant_data, - scales_data, - (const uint8_t*)zero_points, - reorder_idx, - block_size, - groups_per_K, - groups_per_threadblock, - total_groups); + // Launch reorder kernel (requires uint8_t zero points as per original structure) + if constexpr (std::is_same_v) { + Dequantize4BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::INVALID_ARGUMENT, + "Reorder kernel currently expects uint8_t zero points."); + } } - return Status::OK(); + return CUDA_CALL(cudaGetLastError()); // Check for launch errors } template Status Dequantize4Bits( @@ -224,60 +235,23 @@ template Status Dequantize4Bits( int n, int block_size, cudaStream_t stream); -/////////////////////////////////////////////////////////////////////////////// -// A more general block-wise dequantization implementation that supports -// different block sizes and block orientations (row-wise/column-wise). template < - int Row_, ///< rows of a matrix - int Column_ ///< columns of a matrix - > -struct Shape2D { - static int const kRow = Row_; ///< rows of a matrix - static int const kColumn = Column_; ///< columns of a matrix - static int const kCount = Row_ * Column_; ///< total number of elements in a matrix -}; - -/** - * @brief Blockwise quantization constants - * @tparam ElementT source data type, e.g. fp32/fp16 - * @tparam block_size number of elemenets quantized together - * @tparam qbits number of bits in each quantized element - * @tparam Columnwise true: elements in a block come from one single column - * false: elements in a block come from one single row - */ -template < - typename ElementT, - int32_t block_size, - int32_t qbits, - bool Columnwise> -struct BlkQuantTraits { - // number of qbit elements to pack into whole bytes - static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; - static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); - - using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; - using ThreadBlk = Shape2D; -}; - -template < - typename ElementT, - int32_t block_size, - int32_t qbits, - bool Columnwise> -__global__ -void dequantizeThread(ElementT* dst, - const uint8_t* weights, - const ElementT* scales, - const uint8_t* zero_points, - int rows, - int columns, - int thrd_row_blks) { + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +__global__ void dequantizeThread4b(ElementT* dst, + const uint8_t* weights, + const ElementT* scales, + const uint8_t* zero_points, + int rows, + int columns, + int thrd_row_blks) { using QuantBlk = typename BlkQuantTraits::QuantBlk; using ThreadBlk = typename BlkQuantTraits::ThreadBlk; - // !! 4b specific code - static_assert(qbits == 4, "Only 4b block quantization is supported!"); + static_assert(qbits == 4, "Only 4b block quantization is supported by this kernel specialization!!"); const auto block_idx = blockIdx.x * blockDim.x + threadIdx.x; const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; @@ -299,12 +273,12 @@ void dequantizeThread(ElementT* dst, // for 4b quant, kPackSize = 2, so we have 2 scales and 2 offsets const ElementT scale_buf[2] = { scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow], - ((r/QuantBlk::kRow) < (meta_rows - 1)) + ((r / QuantBlk::kRow) < (meta_rows - 1)) ? scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow + 1] : static_cast(0.0f)}; const uint8_t zp_pair = (zero_points == nullptr) - ? 0x88 - : zero_points[(c / QuantBlk::kColumn) * ((row_blks + 1) / 2) + (r / QuantBlk::kRow) / 2]; + ? 0x88 + : zero_points[(c / QuantBlk::kColumn) * ((row_blks + 1) / 2) + (r / QuantBlk::kRow) / 2]; const uint16_t zp_buf[2] = {(uint16_t)(zp_pair & 0x0f), (uint16_t)((zp_pair >> 4) & 0x0f)}; const ElementT adjust_buf[2] = {(-scale_buf[0]) * static_cast(zp_buf[0]), (-scale_buf[1]) * static_cast(zp_buf[1])}; @@ -315,7 +289,7 @@ void dequantizeThread(ElementT* dst, const auto scale0 = scale_buf[(i - r) / QuantBlk::kRow]; const auto adjust0 = adjust_buf[(i - r) / QuantBlk::kRow]; - const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow];; + const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow]; const auto adjust1 = adjust_buf[(i + 1 - r) / QuantBlk::kRow]; const auto vi = q_ptr[i / 2]; @@ -333,7 +307,8 @@ void dequantizeThread(ElementT* dst, static_assert(std::is_same::value, "Only float and half are supported!"); const uint8_t vi0 = vi & 0xf; const uint8_t vi1 = vi >> 4; - dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0;; + dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0; + ; dst[j * rows + (i + 1)] = static_cast(vi1) * scale1 + adjust1; } } @@ -351,13 +326,13 @@ void dequantizeThread(ElementT* dst, } template < - typename ElementT, - int32_t block_size, - int32_t qbits, - bool Columnwise> -static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* scales, - const uint8_t* zero_points, int32_t rows, int32_t columns, - cudaStream_t stream) { + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +static void dequantize4b_generic(ElementT* dst, const uint8_t* weights, const ElementT* scales, + const uint8_t* zero_points, int32_t rows, int32_t columns, + cudaStream_t stream) { using ThreadBlk = typename BlkQuantTraits::ThreadBlk; // Thread partitioning @@ -366,7 +341,7 @@ static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* sc const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; const auto grids = (total_thrd_blks + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; - dequantizeThread<<>>( + dequantizeThread4b<<>>( dst, weights, scales, @@ -376,7 +351,6 @@ static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* sc thrd_row_blks); } - template Status DequantizeBlockwise4b( @@ -392,41 +366,37 @@ DequantizeBlockwise4b( switch (block_size) { case 16: if (columnwise) { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } else { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } return Status::OK(); case 32: if (columnwise) { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } else { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } return Status::OK(); case 64: if (columnwise) { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } else { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } return Status::OK(); case 128: if (columnwise) { - dequantize(dst, src, scales, zero_points, rows, - columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } else { - dequantize(dst, src, scales, zero_points, - rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } return Status::OK(); case 256: if (columnwise) { - dequantize(dst, src, scales, zero_points, rows, - columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } else { - dequantize(dst, src, scales, zero_points, - rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } return Status::OK(); default: @@ -436,8 +406,8 @@ DequantizeBlockwise4b( } } -template -Status DequantizeBlockwise4b( +// Template instantiations for 4-bit blockwise +template Status DequantizeBlockwise4b( float* dst, const uint8_t* src, const float* scales, @@ -448,8 +418,7 @@ Status DequantizeBlockwise4b( int columns, cudaStream_t stream); -template -Status DequantizeBlockwise4b( +template Status DequantizeBlockwise4b( half* dst, const uint8_t* src, const half* scales, diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_8bits.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_8bits.cu new file mode 100644 index 0000000000000..e90ed85b22f02 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_8bits.cu @@ -0,0 +1,465 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "dequantize_blockwise.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Processes 4 elements (since each is 8 bits, 4 fit in uint32_t) +__device__ __forceinline__ void DequantizeFourElements(uint32_t values_quant, half scale, half zp, half* output) { + half2 scale_half2 = {scale, scale}; + // Formula: val = (quant - zp) * scale = quant * scale - zp * scale + half zp_adjust = -scale * zp; + half2 zp_adjust2 = {zp_adjust, zp_adjust}; + + alignas(16) half2 results[2]; // Store 4 half values + + // Extract 4 uint8_t values from uint32_t + half v0 = __ushort2half_rn(static_cast(values_quant & 0xFF)); + half v1 = __ushort2half_rn(static_cast((values_quant >> 8) & 0xFF)); + results[0] = __halves2half2(v0, v1) * scale_half2 + zp_adjust2; + + half v2 = __ushort2half_rn(static_cast((values_quant >> 16) & 0xFF)); + half v3 = __ushort2half_rn(static_cast((values_quant >> 24) & 0xFF)); + results[1] = __halves2half2(v2, v3) * scale_half2 + zp_adjust2; + + // Write 4 half values (equivalent to float2) + *(reinterpret_cast(output)) = *(reinterpret_cast(results)); +} + +// Processes 4 elements (since each is 8 bits, 4 fit in uint32_t) +__device__ __forceinline__ void DequantizeFourElements(uint32_t values_quant, float scale, float zp, float* output) { + // Assuming ZP is symmetric or already adjusted if needed. Standard formula: val = (quant - zp) * scale = quant * scale - zp * scale + float zp_adjust = -scale * zp; + + // Extract 4 uint8_t values from uint32_t + output[0] = float(values_quant & 0xFF) * scale + zp_adjust; + output[1] = float((values_quant >> 8) & 0xFF) * scale + zp_adjust; + output[2] = float((values_quant >> 16) & 0xFF) * scale + zp_adjust; + output[3] = float((values_quant >> 24) & 0xFF) * scale + zp_adjust; +} + +// REVIEW: Deprecate reorder_idx (Recommend to reorder scales and zero points during model conversion instead of using reorder_idx). +// Reorder index is a 1D array of size [K] to support desc_act used in GPTQ quantization. +// However, it impacts inference performance of the kernel since it is not optimized for coalescing memory access. +template +__global__ void Dequantize8BitsKernelReOrder( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const uint8_t* zero_points, // Assuming uint8_t zero points for reorder case + const int32_t* reorder_idx, + int block_size, + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + constexpr int element_per_thread = 4; // Process 4 elements (uint8_t) per thread using uint32_t load + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * element_per_thread) / block_size); + if (group_id >= total_groups) { + return; + } + + // element_offset corresponds to the start of the 4 elements processed by this thread iteration + int element_offset = group_id * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); + + T* output_i = output + element_offset; + + // shape of scales and zero_points is [N, groups_per_K]. Compute the 2D indices below. + int n_idx = group_id / groups_per_K; + int kb_idx = group_id % groups_per_K; + + // Read 4 uint8_t values packed into a uint32_t + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset)); + + // Adjust reorder index pointer to the start of the 4 indices for this thread iteration + const int32_t* g_idx = reorder_idx + kb_idx * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); + + for (int i = 0; i < element_per_thread; i++) { + // Typical value of g_idx is in the range of [0, groups_per_K) for reordering groups. + // No range check here so it might have out-of-bound access if the reorder_idx is not valid. + int32_t rid = g_idx[i]; + ptrdiff_t scale_zp_offset = n_idx * groups_per_K + rid; + T scale = *(scale_data + scale_zp_offset); + + uint8_t zp = 128; // Default zero point + if (zero_points) { + zp = zero_points[scale_zp_offset]; + } + + // Extract the i-th uint8_t value + uint8_t q_val = (quant_value >> (8 * i)) & 0xFF; + + if constexpr (std::is_same_v) { + T zp_T = __ushort2half_rn(zp); + T zp_adjust = -scale * zp_T; + output_i[i] = __ushort2half_rn(q_val) * scale + zp_adjust; + } else { + T zp_T = static_cast(zp); + T zp_adjust = -scale * zp_T; + output_i[i] = static_cast(q_val) * scale + zp_adjust; + } + } +} + +template +__global__ void Dequantize8BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const ZeroT* zero_points, + int block_size, + int groups_per_threadblock, + int total_groups) { + constexpr int element_per_thread = 4; // Process 4 elements (uint8_t) per thread using uint32_t load + int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * element_per_thread) / block_size); + if (block_id >= total_groups) { + return; + } + + // element_offset corresponds to the start of the 4 elements processed by this thread iteration + int element_offset = block_id * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); + + // Read 4 uint8_t values packed into a uint32_t + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset)); + T scale = *(scale_data + block_id); // One scale per block + + T zero_point_value; + if constexpr (std::is_same_v) { + // Assuming one uint8_t zero point per block. Default 128 for uint8 asymmetric. + uint8_t zp = 128; + if (zero_points) { + zp = zero_points[block_id]; // Direct lookup, no packing + } + // Convert uint8_t zp to T (float/half) + if constexpr (std::is_same_v) { + zero_point_value = __uint2half_rn(zp); + } else { + zero_point_value = static_cast(zp); + } + } else { // ZeroT is T (float or half) + // Default 0 for float/half zero point + zero_point_value = zero_points ? *(zero_points + block_id) : static_cast(0.0f); + } + + output = output + element_offset; // Point output to the start of the 4 elements + DequantizeFourElements(quant_value, scale, zero_point_value, output); +} + +template +Status Dequantize8Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const ZeroT* zero_points, // Shape: [N, K_blocks] or [N * K_blocks] + const int32_t* reorder_idx, // If provided, ZeroT is expected to be uint8_t + int k, // Original dimension before padding + int n, // Other dimension + int block_size, + cudaStream_t stream) { + ORT_ENFORCE(k % block_size == 0, "k must be a multiple of block_size"); // K shall be padded to multiple of block_size. + + constexpr int element_per_thread = 4; + int groups_per_K = k / block_size; + int total_groups = n * groups_per_K; // Total number of blocks + + assert(block_size <= GridDim::maxThreadsPerBlock * element_per_thread); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_grid = CeilDiv(total_groups, groups_per_threadblock); + + dim3 grid_dim(groups_per_grid); + dim3 block_dim(GridDim::maxThreadsPerBlock); + + DUMP_TENSOR_INIT(); + if (!reorder_idx || std::is_same_v) { + DUMP_STRING("Launch standard kernel for Dequantize8Bits"); + Dequantize8BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + groups_per_threadblock, + total_groups); + } else { + if constexpr (std::is_same_v) { + DUMP_STRING("Launch reorder kernel for Dequantize8Bits"); + Dequantize8BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::INVALID_ARGUMENT, + "Reorder kernel currently expects uint8_t zero points."); + } + } + + return CUDA_CALL(cudaGetLastError()); // Check for launch errors +} + +// Template instantiations for 8-bit +template Status Dequantize8Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + +template Status Dequantize8Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + +template Status Dequantize8Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const float* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + +template Status Dequantize8Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const half* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + +// Generic dequantization kernel for 8 bits +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +__global__ void dequantizeThread8b(ElementT* dst, + const uint8_t* weights, // Quantized data (uint8_t) + const ElementT* scales, + const uint8_t* zero_points, // Assuming uint8_t zero points + int rows, + int columns, + int thread_row_blocks) { // Number of thread blocks along row dimension + + using QuantBlk = typename BlkQuantTraits::QuantBlk; + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + static_assert(qbits == 8, "Only 8b block quantization is supported by this kernel specialization!"); + + const auto thread_idx_global = blockIdx.x * blockDim.x + threadIdx.x; + + // Total blocks along row dim for scales/zp + const auto total_row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + // Total blocks along col dim for scales/zp + // const auto total_col_blks = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn; + + // Total number of blocks to process + // const auto total_quant_blocks = total_row_blks * total_col_blks; + + // Iterate over the quantization blocks assigned to this thread + // Each thread might process multiple QuantBlks + // This loop structure assumes 1D grid/block launch. A 2D launch might map threads differently. + const auto block_idx = thread_idx_global; // Assuming 1 thread processes 1 ThreadBlk here + + // Calculate row and column block indices for this thread + // Map 1D block_idx back to 2D block indices (row_blk, col_blk) + const auto r_blk_idx_thread = static_cast(block_idx % thread_row_blocks); // Thread block index along rows + const auto c_blk_idx_thread = static_cast(block_idx / thread_row_blocks); // Thread block index along columns + + // Calculate starting row and column for this thread's work item (ThreadBlk) + int32_t r_start = r_blk_idx_thread * ThreadBlk::kRow; + int32_t c_start = c_blk_idx_thread * ThreadBlk::kColumn; + + // Check if this thread is out of bounds for the overall work + if (c_start >= columns) { + return; + } + + // Determine the actual end row/column considering matrix boundaries + int32_t r_end = std::min(r_start + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c_start + ThreadBlk::kColumn, columns); + + // Process elements within the assigned ThreadBlk + for (int32_t c = c_start; c < c_end; ++c) { + // Calculate the block index for scale/zp lookup based on the current column 'c' + const auto scale_zp_col_blk_idx = c / QuantBlk::kColumn; + + // Calculate base pointer for this column in the quantized weights matrix + // Assuming weights stored column-major: shape [rows, columns] -> layout [columns, rows] + // Each element is uint8_t. + // const uint8_t* q_col_ptr = weights + static_cast(scale_zp_col_blk_idx) * rows; + + for (int32_t r = r_start; r < r_end; ++r) { + // Calculate the block index for scale/zp lookup based on current row 'r' + const auto scale_zp_row_blk_idx = r / QuantBlk::kRow; + const auto scale_zp_flat_idx = scale_zp_col_blk_idx * total_row_blks + scale_zp_row_blk_idx; + + // Get scale and zero point for this block + const ElementT scale = scales[scale_zp_flat_idx]; + const uint8_t zp_uint8 = (zero_points == nullptr) ? 128 : zero_points[scale_zp_flat_idx]; + + // Get the quantized value (uint8_t) + // Assuming weights are stored col-major for block quantization (e.g. [cols, rows/block_size, block_size]) + // Row-major logical layout for weights access: index = c * rows + r + const size_t q_val_idx = static_cast(c) * rows + r; + const uint8_t q_val = weights[q_val_idx]; + + // Dequantize + if constexpr (std::is_same::value) { + const half zp_half = __uint2half_rn(zp_uint8); + const half adjust = -scale * zp_half; + dst[q_val_idx] = __uint2half_rn(q_val) * scale + adjust; + } else { // Float + static_assert(std::is_same::value, "Only float and half are supported!"); + const float zp_float = static_cast(zp_uint8); + const float adjust = -scale * zp_float; + dst[q_val_idx] = static_cast(q_val) * scale + adjust; + } + } + } +} + +// Launcher function for the generic 8-bit kernel +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +static void dequantize8b_generic(ElementT* dst, const uint8_t* weights, const ElementT* scales, + const uint8_t* zero_points, int32_t rows, int32_t columns, + cudaStream_t stream) { + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + const auto thread_row_blocks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thread_col_blocks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto thread_total_blocks = thread_row_blocks * thread_col_blocks; + + const auto grids = (thread_total_blocks + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; + dequantizeThread8b<<>>( + dst, + weights, + scales, + zero_points, + rows, + columns, + thread_row_blocks); +} + +template +Status +DequantizeBlockwise8b( + T* dst, + const uint8_t* src, // Quantized uint8_t data + const T* scales, + const uint8_t* zero_points, // Assuming uint8_t zero points + int block_size, + bool columnwise, // Orientation of elements within a block + int rows, + int columns, + cudaStream_t stream) { + // Use the generic launcher, passing qbits=8 + switch (block_size) { + case 16: + if (columnwise) { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 32: + if (columnwise) { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 64: + if (columnwise) { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 128: + if (columnwise) { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 256: + if (columnwise) { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + default: + // Only block size 16, 32, 64, 128, 256 are supported. + return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::FAIL, + "Unsupported block size for 8b blockwise quantization."); + } +} + +// Template instantiations for 8-bit blockwise +template Status DequantizeBlockwise8b( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +template Status DequantizeBlockwise8b( + half* dst, + const uint8_t* src, + const half* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu similarity index 93% rename from onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu rename to onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu index ce6c07fbed2bc..5d634b8a929f1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu @@ -1,4 +1,3 @@ -// Modifications: scaling is moved from masked softmax to the gemm before that. // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. @@ -89,7 +88,7 @@ __device__ __forceinline__ void Convert8xInt4To8xHalfs(uint32_t value, half2* ha asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(kOneSixteenth), "r"(kNeg64)); } -__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements4b(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -120,7 +119,7 @@ __device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, h sums_half2[3] = sums_half2[3] + v3 * (*(reinterpret_cast(&(vec_permuted.w)))); } #else -__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements4b(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -144,7 +143,7 @@ __device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, h } #endif -__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { +__device__ __forceinline__ void AccumulateEightElements4b(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { float4 a_vec_0 = *(reinterpret_cast(a)); float4 a_vec_1 = *(reinterpret_cast(a + 4)); @@ -178,7 +177,7 @@ constexpr int kWarpSize = GPU_WARP_SIZE; // Each thread block computes [1, K] x [kColsPerThreadBlock, (K + block_size - 1)/block_size, blob], // i.e., computing kColsPerThreadBlock per block and a warp reduce (1, K) x (K) template -__global__ void __launch_bounds__(kWarpSize * kColsPerThreadBlock) MatMulFloatInt4Kernel( +__global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloatInt4Kernel( T* output, const T* a_data, const uint8_t* b_data_quant, @@ -238,7 +237,7 @@ __global__ void __launch_bounds__(kWarpSize * kColsPerThreadBlock) MatMulFloatIn if constexpr (has_zero_point) { \ zp = b_zp_vec[t_meta_k + k_per_iter / block_size * i]; \ } \ - AccumulateEightElements(value, scale, zp, a_data + k_id + i * k_per_iter, sums); \ + AccumulateEightElements4b(value, scale, zp, a_data + k_id + i * k_per_iter, sums); \ } \ b_data_quant += k_per_iter / 2 * kUnroll; \ t_meta_k += k_per_iter / block_size * kUnroll; \ @@ -258,7 +257,7 @@ __global__ void __launch_bounds__(kWarpSize * kColsPerThreadBlock) MatMulFloatIn if constexpr (has_zero_point) { zp = b_zp_vec[t_meta_k]; } - AccumulateEightElements(value, scale, zp, a_data + k_id, sums); + AccumulateEightElements4b(value, scale, zp, a_data + k_id, sums); } float sum = (float)(sums[0] + sums[1] + sums[2] + sums[3] + sums[4] + sums[5] + sums[6] + sums[7]); @@ -283,7 +282,7 @@ bool TryMatMul4Bits( int n, int k, int block_size, - int shared_mem_per_block, + size_t shared_mem_per_block, cudaStream_t stream) { if (n % kColsPerThreadBlock != 0 || k % 8 != 0 || m > 1) { return false; @@ -291,8 +290,8 @@ bool TryMatMul4Bits( dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); dim3 threads(GPU_WARP_SIZE_HOST, kColsPerThreadBlock); int blocks_per_K = (k + block_size - 1) / block_size; - int shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock + - (zero_points != nullptr ? (blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0); + size_t shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock + + static_cast(zero_points != nullptr ? (blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0); if (shared_mem_size > shared_mem_per_block) { return false; } @@ -333,7 +332,7 @@ template bool TryMatMul4Bits( int n, int k, int block_size, - int shared_mem_per_block, + size_t shared_mem_per_block, cudaStream_t stream); template bool TryMatMul4Bits( @@ -346,7 +345,7 @@ template bool TryMatMul4Bits( int n, int k, int block_size, - int shared_mem_per_block, + size_t shared_mem_per_block, cudaStream_t stream); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu new file mode 100644 index 0000000000000..30fbb486378a8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu @@ -0,0 +1,427 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "matmul_nbits.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// --- Kernel Configuration Constants --- +constexpr int kColsPerThreadBlock = 8; // Number of columns (N dimension) processed per thread block +constexpr int kElementsPerThreadPerIteration = 8; // Number of elements (K dimension) processed per thread per iteration +constexpr int kWarpSize = GPU_WARP_SIZE; // Typically 32 +constexpr uint8_t kDefaultZeroPoint = 128; // Default zero point if not provided + +// --- Device Function: Accumulate 8 Elements (half precision) --- +// Dequantizes 8 uint8_t values and accumulates the result with 8 half values from A. +// sums += A * dequant(B_quant) +__device__ __forceinline__ void AccumulateEightElements8b( + uint64_t values_quant, // 8 packed uint8_t values from B + half scale, // Dequantization scale for this block + uint8_t zp, // Dequantization zero point for this block + const half* a, // Pointer to 8 half values from A + half* sums) { // Pointer to 8 partial sums (half) + +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530) + // --- Dequantization Setup --- + half2 scale_h2 = __half2half2(scale); // Broadcast scale + half zp_h = __ushort2half_rn(zp); // Convert zp to half + half2 zp_h2 = __half2half2(zp_h); // Broadcast zp_h + + // --- Extract 8 uint8_t values --- + uint8_t q[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + q[i] = (values_quant >> (i * 8)) & 0xFF; + } + + // --- Dequantize 8 values into 4 half2 vectors: b_vec = (q - zp) * scale --- + half2 q_01 = __halves2half2(__ushort2half_rn(q[0]), __ushort2half_rn(q[1])); + half2 q_23 = __halves2half2(__ushort2half_rn(q[2]), __ushort2half_rn(q[3])); + half2 q_45 = __halves2half2(__ushort2half_rn(q[4]), __ushort2half_rn(q[5])); + half2 q_67 = __halves2half2(__ushort2half_rn(q[6]), __ushort2half_rn(q[7])); + + half2 diff_01 = __hsub2(q_01, zp_h2); + half2 diff_23 = __hsub2(q_23, zp_h2); + half2 diff_45 = __hsub2(q_45, zp_h2); + half2 diff_67 = __hsub2(q_67, zp_h2); + + half2 b_vec0 = __hmul2(diff_01, scale_h2); // {b0, b1} + half2 b_vec1 = __hmul2(diff_23, scale_h2); // {b2, b3} + half2 b_vec2 = __hmul2(diff_45, scale_h2); // {b4, b5} + half2 b_vec3 = __hmul2(diff_67, scale_h2); // {b6, b7} + + // --- Load Input A (8 half values as 4 half2 vectors) --- + // Assumes 'a' is properly aligned for half2 reads. + const half2* a_half2 = reinterpret_cast(a); + half2 a_vec0 = a_half2[0]; // {a0, a1} + half2 a_vec1 = a_half2[1]; // {a2, a3} + half2 a_vec2 = a_half2[2]; // {a4, a5} + half2 a_vec3 = a_half2[3]; // {a6, a7} + + // --- Accumulate: sums += a * b_vec using half2 FMA --- + half2* sums_half2 = reinterpret_cast(sums); + sums_half2[0] = __hfma2(a_vec0, b_vec0, sums_half2[0]); // {s0+=a0*b0, s1+=a1*b1} + sums_half2[1] = __hfma2(a_vec1, b_vec1, sums_half2[1]); // {s2+=a2*b2, s3+=a3*b3} + sums_half2[2] = __hfma2(a_vec2, b_vec2, sums_half2[2]); // {s4+=a4*b4, s5+=a5*b5} + sums_half2[3] = __hfma2(a_vec3, b_vec3, sums_half2[3]); // {s6+=a6*b6, s7+=a7*b7} + +#else // older GPUs of compute capability < 5.3, which lacks native half support. + float scale_f = __half2float(scale); + float zp_f = static_cast(zp); + + float b_dequant[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + uint8_t q = (values_quant >> (i * 8)) & 0xFF; + b_dequant[i] = (static_cast(q) - zp_f) * scale_f; + } + +#pragma unroll + for (int i = 0; i < 8; ++i) { + float a_f = __half2float(a[i]); + float product_f = a_f * b_dequant[i]; + // Convert back to half for partial sums. It is not ideal for performance. + half product_h = __float2half_rn(product_f); + sums[i] += product_h; + } +#endif +} + +// --- Device Function: Accumulate 8 Elements (float precision) --- +// Dequantizes 8 uint8_t values and accumulates the result with 8 float values from A. +// sums += A * dequant(B_quant) +__device__ __forceinline__ void AccumulateEightElements8b( + uint64_t values_quant, // 8 packed uint8_t values from B + float scale, // Dequantization scale for this block + uint8_t zp, // Dequantization zero point for this block + const float* a, // Pointer to 8 float values from A + float* sums) { // Pointer to 8 partial sums (float) + + // Load A using float4 for potentially better memory bandwidth + float4 a_vec_0 = *(reinterpret_cast(a)); + float4 a_vec_1 = *(reinterpret_cast(a + 4)); + + // Precompute scale * (-zp) adjustment + float zp_adjust = -scale * float(zp); + + // Extract, dequantize, and accumulate 8 float values + float v[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + uint8_t q_val = (values_quant >> (i * 8)) & 0xFF; + // Dequantize: float(q_val) * scale - scale * float(zp) = float(q_val) * scale + zp_adjust + v[i] = float(q_val) * scale + zp_adjust; + } + + // Accumulate using fmaf (fused multiply-add) + sums[0] = fmaf(v[0], a_vec_0.x, sums[0]); + sums[1] = fmaf(v[1], a_vec_0.y, sums[1]); + sums[2] = fmaf(v[2], a_vec_0.z, sums[2]); + sums[3] = fmaf(v[3], a_vec_0.w, sums[3]); + sums[4] = fmaf(v[4], a_vec_1.x, sums[4]); + sums[5] = fmaf(v[5], a_vec_1.y, sums[5]); + sums[6] = fmaf(v[6], a_vec_1.z, sums[6]); + sums[7] = fmaf(v[7], a_vec_1.w, sums[7]); +} + +// --- CUDA Kernel: MatMulFloat8bKernel (Optimized for m=1) --- +// Computes C(1, N) = A(1, K) x B(K, N) +// B(K, N) is quantized with 8 bits and block_size bs, stored as [N, K/bs, bs] +// +// Template Parameters: +// T: Data type for A and C (float or half) +// block_size: Quantization block size for B +// has_zero_point: Boolean indicating if zero points are provided +// +// Grid size: (Ceil(N / kColsPerThreadBlock), 1) +// Block size: (kWarpSize, kColsPerThreadBlock) = (32, 8) +template +__global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloat8bKernelM1( + T* output, // Output C [1, N] (effectively [N]) + const T* a_data, // Input A [1, K] (effectively [K]) + const uint8_t* b_data_quant, // Quantized Input B [N, K/bs, bs] + const T* scales_data, // Scales [N, K/bs] + const uint8_t* zero_points, // Zero Points [N, K/bs] (optional) + int n, // Columns in B and C (Constraint: N % kColsPerThreadBlock == 0) + int k, // Columns in A / Rows in B (Constraint: K % kElementsPerThreadPerIteration == 0) + int blocks_per_K) { // K / block_size (rounded up) + + // --- Thread Indexing --- + const int n_block_id = blockIdx.x; // Block column index [0, Ceil(N / kColsPerThreadBlock)) + // m_id is implicitly 0 since blockDim.y is 1 + + const int lane_id = threadIdx.x; // Thread index in warp (0..31) + const int warp_id = threadIdx.y; // Warp index in block (0..kColsPerThreadBlock-1) + + // Calculate the starting column index (n_id) this warp is responsible for + const int n_block_head = n_block_id * kColsPerThreadBlock; + const int n_id = n_block_head + warp_id; // Global output column index for this warp + + // Boundary check for the column index (safety check, though N % kColsPerThreadBlock==0 is enforced) + if (n_id >= n) return; + + // --- Shared Memory Allocation --- + extern __shared__ char shared_buffer[]; + // Shared memory for scales + T* b_scale_vec_shared = reinterpret_cast(shared_buffer); + // Shared memory for zero points (if used) immediately after scales + [[maybe_unused]] uint8_t* b_zp_vec_shared = nullptr; // Initialize to avoid unused warning + if constexpr (has_zero_point) { + b_zp_vec_shared = reinterpret_cast(b_scale_vec_shared + kColsPerThreadBlock * blocks_per_K); + } + + // --- Load Scales and Zero Points into Shared Memory --- + // Each thread loads a portion of the scales/ZPs for the columns handled by this block + for (int i = threadIdx.y * kWarpSize + threadIdx.x; // Linear thread index within the block + i < kColsPerThreadBlock * blocks_per_K; // Total elements to load for the block + i += kColsPerThreadBlock * kWarpSize) { // Stride by total threads in block + int current_n_offset = i / blocks_per_K; // Column offset within the block [0, kColsPerThreadBlock-1] + int current_k_block = i % blocks_per_K; // K block index [0, blocks_per_K-1] + int current_n = n_block_head + current_n_offset; // Global N index + + if (current_n < n) { // Boundary check for N + // Calculate global index into scales/ZPs: N * blocks_per_K + k_block + int64_t scale_zp_idx = static_cast(current_n) * blocks_per_K + current_k_block; + // Load scale + b_scale_vec_shared[i] = scales_data[scale_zp_idx]; + // Load zero point if applicable + if constexpr (has_zero_point) { + b_zp_vec_shared[i] = zero_points[scale_zp_idx]; + } + } + } + + __syncthreads(); // Ensure all scales and ZPs are loaded before proceeding + + // --- Pointers Setup --- + // A data pointer (since m=1, no row offset needed) + const T* a_row_data = a_data; + + // Each thread calculates its part of the dot product along K. + // Point to the start of the elements this thread is responsible for in A. + const int lane_offset = lane_id * kElementsPerThreadPerIteration; // Offset in K for this thread + const T* a_thread_data_base = a_row_data + lane_offset; // Base pointer in A for this thread + + // Base pointer to B data for the specific column n_id this warp handles. + // Layout of B is [N, K/bs, bs]. + const uint8_t* b_base_ptr_n = b_data_quant + static_cast(n_id) * blocks_per_K * block_size; + + // Pointer to the start of scales for column n_id (from shared memory) + const T* b_scale_vec_thread = b_scale_vec_shared + warp_id * blocks_per_K; + + // Pointer to the start of zero points for column n_id (from shared memory, if used) + [[maybe_unused]] const uint8_t* b_zp_vec_thread = nullptr; // Initialize to avoid unused warning + if constexpr (has_zero_point) { + b_zp_vec_thread = b_zp_vec_shared + warp_id * blocks_per_K; + } + + // --- Accumulation --- + // Initialize partial sums for this thread to zero + // Note that partial sum uses original data type. It is a trade-off between performance and accuracy. + // For example, K=3072, each accumulates k / k_per_iter = 3072 / 256 = 12 elements. + T sums[kElementsPerThreadPerIteration] = {static_cast(0.0f)}; + + constexpr int k_per_iter = kWarpSize * kElementsPerThreadPerIteration; // Elements processed per warp per iteration (e.g., 32*8 = 256) + int k_id = 0; // Current position along the K dimension + + // Pointer to B data for this thread's starting element in K, for column n_id. + const uint8_t* b_data_quant_thread = b_base_ptr_n + lane_offset; + + for (; k_id + k_per_iter <= k; k_id += k_per_iter) { + const uint8_t* current_b_ptr = b_data_quant_thread + k_id; + uint64_t value = *reinterpret_cast(current_b_ptr); + + int current_meta_k = (lane_offset + k_id) / block_size; + T scale = b_scale_vec_thread[current_meta_k]; + uint8_t zp = kDefaultZeroPoint; + if constexpr (has_zero_point) { + zp = b_zp_vec_thread[current_meta_k]; + } + + AccumulateEightElements8b(value, scale, zp, a_thread_data_base + k_id, sums); + } + + // Handle the tail elements along K dimension for this thread. + // This loop handles the final iteration if k is not a multiple of k_per_iter. + // Since K % kElementsPerThreadPerIteration == 0 is enforced, each thread + // processes a full set of kElementsPerThreadPerIteration if it has work left. + if (lane_offset + k_id < k) { // Check if this thread has remaining elements + const uint8_t* current_b_ptr = b_data_quant_thread + k_id; + uint64_t value = *reinterpret_cast(current_b_ptr); + + // Calculate k_block index for the tail part + int current_meta_k = (lane_offset + k_id) / block_size; + T scale = b_scale_vec_thread[current_meta_k]; + uint8_t zp = kDefaultZeroPoint; + if constexpr (has_zero_point) { + zp = b_zp_vec_thread[current_meta_k]; + } + // Pointer to A data for the tail part + const T* current_a_ptr = a_thread_data_base + k_id; + // Perform dequantization and accumulation + AccumulateEightElements8b(value, scale, zp, current_a_ptr, sums); + } + + // --- Intra-Thread Reduction --- + // Sum the kElementsPerThreadPerIteration partial sums within each thread. + // Here we use float to accumulate to avoid precision loss. + float total_sum_thread = 0.0f; + +#pragma unroll + for (int i = 0; i < kElementsPerThreadPerIteration; ++i) { + total_sum_thread += static_cast(sums[i]); + } + + // --- Inter-Thread Reduction (Warp Level) --- + // Use CUB for efficient and robust warp reduction + using BlockReduce = cub::WarpReduce; + // Allocate shared memory for CUB temporary storage (one per warp) + __shared__ typename BlockReduce::TempStorage temp_storage[kColsPerThreadBlock]; + + // Perform warp-level sum reduction. Use float in accumulation to avoid precision loss. + total_sum_thread = BlockReduce(temp_storage[warp_id]).Sum(total_sum_thread); + + // Lane 0 of each warp writes the final reduced sum to global memory + if (lane_id == 0) { + // Write result (cast back to T) + // Since m=1, output index is just n_id + output[n_id] = static_cast(total_sum_thread); + } +} + +// --- Host Function: TryMatMul8Bits (Optimized for m=1) --- +// Launches the MatMulFloat8bKernelM1 kernel if constraints are met. +// Enforces m == 1. +template +bool TryMatMul8Bits( + T* output, // Output C [1, N] + const T* a_data, // Input A [1, K] + const uint8_t* b_data_quant, // Input B Quantized [N, K/bs, bs] + const T* scales_data, // Scales [N, K/bs] + const uint8_t* zero_points, // Zero Points [N, K/bs] (can be nullptr) + int m, // Rows of A and C (MUST be 1) + int n, // Columns of B and C + int k, // Columns of A / Rows of B + int block_size, // Quantization block size for B + size_t shared_mem_per_block, // Available shared memory per block + cudaStream_t stream) { + // Constraints Check + // m must be 1 (since this kernel is optimized for m=1) + // N must be a multiple of kColsPerThreadBlock (8) for warps to align with columns. + // K must be a multiple of kElementsPerThreadPerIteration (8) for full uint64_t reads/processing. + if (m != 1 || n % kColsPerThreadBlock != 0 || k % kElementsPerThreadPerIteration != 0) { + return false; + } + + // Ensure k_per_iter (kWarpSize * kElementsPerThreadPerIteration) is multiple of block_size. + constexpr int k_per_iter = kWarpSize * kElementsPerThreadPerIteration; + if (k_per_iter % block_size != 0) { + // This constraint is needed for the scale/zp indexing calculation within the unrolled loop. + return false; + } + + // K must be a multiple of block_size for correct scale/zp lookup within blocks. + // While blocks_per_K handles rounding up, the kernel logic assumes full blocks for indexing. + if (k % block_size != 0) { + return false; + } + + // --- Grid and Thread Block Configuration --- + dim3 threads(kWarpSize, kColsPerThreadBlock); // Block dimensions (32, 8) + dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); + + // Calculate K / block_size (no rounding needed due to k % block_size == 0 check) + int blocks_per_K = k / block_size; + + // --- Shared Memory Calculation --- + // Memory for scales + optional zero points for the columns handled by the block + size_t scale_zp_shared_mem = (sizeof(T) + (zero_points != nullptr ? sizeof(uint8_t) : 0)) * + static_cast(blocks_per_K) * kColsPerThreadBlock; + + size_t total_shared_mem = scale_zp_shared_mem; + + // Add shared memory for CUB reduction storage if used + total_shared_mem += static_cast(kColsPerThreadBlock) * sizeof(typename cub::WarpReduce::TempStorage); + + // Check if required shared memory exceeds device limits for the block + if (total_shared_mem > shared_mem_per_block) { + return false; + } + + // --- Kernel Launch --- + // Macro to simplify dispatching based on block size and presence of zero_points +#define MatMulFloat8bKernelM1Dispatch(bs) \ + if (nullptr != zero_points) { \ + /* Launch kernel with zero points */ \ + MatMulFloat8bKernelM1<<>>( \ + output, a_data, b_data_quant, scales_data, zero_points, n, k, blocks_per_K); \ + } else { \ + /* Launch kernel without zero points (passing nullptr) */ \ + MatMulFloat8bKernelM1<<>>( \ + output, a_data, b_data_quant, scales_data, nullptr /*zero_points*/, n, k, blocks_per_K); \ + } + + // Dispatch based on the provided block_size value + // Note: Only block sizes compatible with k_per_iter % block_size == 0 and k % block_size == 0 will pass checks. + if (16 == block_size) { + MatMulFloat8bKernelM1Dispatch(16); + } else if (32 == block_size) { + MatMulFloat8bKernelM1Dispatch(32); + } else if (64 == block_size) { + MatMulFloat8bKernelM1Dispatch(64); + } else if (128 == block_size) { + MatMulFloat8bKernelM1Dispatch(128); + } else if (256 == block_size) { + MatMulFloat8bKernelM1Dispatch(256); + } else { + // Unsupported block size + return false; + } + +#undef MatMulFloat8bKernelM1Dispatch + + // Kernel launch succeeded (error checking, e.g., cudaGetLastError(), should be done by the caller) + return true; +} + +// --- Template Instantiations --- +template bool TryMatMul8Bits( + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + size_t shared_mem_per_block, + cudaStream_t stream); + +template bool TryMatMul8Bits( + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + size_t shared_mem_per_block, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 1cec6f6a12f1c..33265744f3a7d 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -8,6 +8,7 @@ #include "core/common/status.h" #include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "matmul_nbits.cuh" #include "dequantize_blockwise.cuh" @@ -23,6 +24,10 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); const Tensor* reorder_idx = ctx->Input(4); + const Tensor* bias = ctx->Input(5); + if (bias != nullptr) { + ORT_THROW("MatMulNBits does not support bias in CUDA kernel"); + } const auto* a_data = a->Data(); const uint8_t* blob_data = b->Data(); @@ -40,80 +45,133 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { helper.Compute(a->Shape(), b_shape, transa, transb)); Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty - if (Y->Shape().Size() == 0) return Status::OK(); - - bool is_4bit_done = (reorder_idx_data == nullptr) && - (!zero_points || !zero_points->IsDataType()) && - TryMatMul4Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - static_cast(zero_points_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - SafeInt(GetDeviceProp().sharedMemPerBlock), - static_cast(ctx->GetComputeStream()->GetHandle())); - - if (is_4bit_done) { + if (Y->Shape().Size() == 0) return Status::OK(); + + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + bool done = (nbits_ == 8) ? TryMatMul8Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + GetDeviceProp().sharedMemPerBlock, + static_cast(ctx->GetComputeStream()->GetHandle())) + : TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + GetDeviceProp().sharedMemPerBlock, + static_cast(ctx->GetComputeStream()->GetHandle())); + if (done) { + return Status::OK(); + } } int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); auto* b_data = b_data_ptr.get(); - if (column_wise_quant_blk_) { - if (reorder_idx) { - ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); - } - // column-wise block - if ((zero_points && zero_points->IsDataType())) { - ORT_RETURN_IF_ERROR(Dequantize4Bits( + + if (nbits_ == 8) { + if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } + if (zero_points && zero_points->IsDataType()) { + ORT_RETURN_IF_ERROR(Dequantize8Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const CudaT*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } else { + ORT_RETURN_IF_ERROR(Dequantize8Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } + } else { // row-wise block + ORT_RETURN_IF_ERROR(DequantizeBlockwise8b( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - (const CudaT*)zero_points_data, - reorder_idx_data, - SafeInt(K_padded), - SafeInt(N_), + (const uint8_t*)zero_points_data, SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), static_cast(ctx->GetComputeStream()->GetHandle()))); + } + } else { // 4 bits + if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } + // column-wise block + if ((zero_points && zero_points->IsDataType())) { + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const CudaT*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } else { + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } } else { - ORT_RETURN_IF_ERROR(Dequantize4Bits( + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), (const uint8_t*)zero_points_data, - reorder_idx_data, - SafeInt(K_padded), - SafeInt(N_), SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), static_cast(ctx->GetComputeStream()->GetHandle()))); } - } else { - // row-wise block - K_padded = K_; - - ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( - reinterpret_cast(b_data), - blob_data, - reinterpret_cast(scales_data), - (const uint8_t*)zero_points_data, - SafeInt(block_size_), - column_wise_quant_blk_, - SafeInt(K_), - SafeInt(N_), - static_cast(ctx->GetComputeStream()->GetHandle()))); } -#if 0 -cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); -T* b_data_cpu = new T[K_ * N_]; -cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); -delete[] b_data_cpu; -#endif + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("DeQuantized", b_data, N_, K_padded); const CudaT alpha = ToCudaType::FromFloat(1.f); const CudaT zero = ToCudaType::FromFloat(0.f); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh index 9ccbe4c4d97a8..fe7098b92cba8 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh @@ -19,7 +19,21 @@ bool TryMatMul4Bits( int n, int k, int block_size, - int shared_mem_per_block, + size_t shared_mem_per_block, + cudaStream_t stream); + +template +bool TryMatMul8Bits( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + size_t shared_mem_per_block, cudaStream_t stream); } // namespace cuda diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 6e7919f281fb6..936b4483201ac 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -155,7 +155,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {\n" << " let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n" - << " let outputIdx = headOffset + m + local_id.y * uniforms.N + n + local_id.x;\n" + << " let outputIdx = headOffset + (m + local_id.y) * uniforms.N + n + local_id.x;\n" << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; @@ -446,7 +446,8 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; - const int total_sequence_length = parameters.total_sequence_length_; + const int total_sequence_length = + parameters.is_gqa_ ? parameters.total_sequence_length_ : past_sequence_length + parameters.kv_sequence_length_; const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, total_sequence_length}); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index ad8319aeff1ad..e3353c921094a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -9,7 +9,9 @@ namespace contrib { namespace webgpu { namespace { -constexpr std::string_view commonFunctions = R"ADDNL_FN( +std::string CommonFunctions(uint32_t nbits) { + if (nbits == 4) { + return R"ADDNL_FN( fn DequantizedFrom4BitsTo8Bits(in: vec2) -> vec4 { var out = vec4(0); @@ -38,6 +40,37 @@ constexpr std::string_view commonFunctions = R"ADDNL_FN( return output_element_t(local_sum) * scale; } )ADDNL_FN"; + } else { + ORT_ENFORCE(nbits == 8, "Only 4/8 bits are supported for webgpu matmulnbits"); + // For 8bits, in case data overflow when converting from int32 (output of dot4I8Packed) to f16, we force it convert to f32. + // Then do the scale. Finally, convert to output element type. + return R"ADDNL_FN( + fn AlignWithZeroPoint(in: vec4) -> vec4 + { + var out = vec4(0); + out[0] = pack4xI8(vec4(unpack4xU8(in[0])) - vec4(128)); + out[1] = pack4xI8(vec4(unpack4xU8(in[1])) - vec4(128)); + out[2] = pack4xI8(vec4(unpack4xU8(in[2])) - vec4(128)); + out[3] = pack4xI8(vec4(unpack4xU8(in[3])) - vec4(128)); + return out; + } + + // Scaled dot product of 8 packed unsigned integers. + fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t + { + var local_sum = dot4I8Packed(a1[0], b1[0]); + local_sum += dot4I8Packed(a1[1], b1[1]); + local_sum += dot4I8Packed(a1[2], b1[2]); + local_sum += dot4I8Packed(a1[3], b1[3]); + local_sum += dot4I8Packed(a2[0], b2[0]); + local_sum += dot4I8Packed(a2[1], b2[1]); + local_sum += dot4I8Packed(a2[2], b2[2]); + local_sum += dot4I8Packed(a2[3], b2[3]); + return output_element_t(f32(local_sum) * f32(scale)); + } + )ADDNL_FN"; + } +} } // namespace @@ -98,7 +131,7 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { // this shader require A to be int8 quantized with block size 64. B is regular // matmulnbits input with block size 32. - shader.AdditionalImplementation() << commonFunctions + shader.AdditionalImplementation() << CommonFunctions(nbits_) << " const block_size = " << block_size_ << ";"; shader.AdditionalImplementation() << R"ADDNL_FN( @@ -129,7 +162,9 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; } } - + )ADDNL_FN"; + if (nbits_ == 4) { + shader.AdditionalImplementation() << R"ADDNL_FN( fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) { let b_global = b_global_base + row; @@ -147,6 +182,27 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { } } )ADDNL_FN"; + } else { + ORT_ENFORCE(nbits_ == 8, "Only 4/8 bits are supported for webgpu matmulnbits"); + shader.AdditionalImplementation() << R"ADDNL_FN( + fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let b_global = b_global_base + row; + if (b_global >= uniforms.N) + { + return; + } + + let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; + tile_B[col][row] = AlignWithZeroPoint(b_value); + if (col == 0) + { + // kidx_v - each kidx_v covers 16 values of k + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)]; + } + } + )ADDNL_FN"; + } shader.MainFunctionBody() << R"MAIN_FN( // During the load phase we use all 256 threads to load 64 rows of A/B. @@ -289,18 +345,19 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co // - Stores intermediate results in shared memory (inter_results) // - Iterates through columns accumulating results in inter_results // - Performs final reduction sum in inter_results for output - shader.AdditionalImplementation() << "const tile_size = " << tile_size_ << "u;\n" - << "const tile_size_k_vec = " << tile_size_k_vec << "u;\n" - // sub_tile_size is the number of concurrent b rows processed by the workgroup. - << "const sub_tile_size = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n"; - shader.AdditionalImplementation() << commonFunctions + shader.AdditionalImplementation() << " const tile_size = " << tile_size_ << "u;\n" + << " const tile_size_k_vec = " << tile_size_k_vec << "u;\n" + << " const double_tile_size_k_vec = " << 2 * tile_size_k_vec << "u;\n" + // sub_tile_count is the number of concurrent b rows processed by the workgroup. + << " const sub_tile_count = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n" + << " var inter_results: array, tile_size>;\n"; + + shader.AdditionalImplementation() << CommonFunctions(nbits_) << R"ADDNL_FN( - // Shared memory // Need 2 * tile_size_k_vec (32) to store a tile_A since b is quantized as 4 bits and a is quantized as 8 bits. var tile_A : array, 32>; // Need 4 scales value since each tile_A includes 512 (4x4x32) scalars and the block_size is 128. var scale_A : array; - var inter_results: array, tile_size>; fn loadSHMA(a_global: u32, kidx_v: u32, col: u32) { let k_offset = kidx_v + col; @@ -320,13 +377,13 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co shader.MainFunctionBody() << R"MAIN_FN( let a_global = u32(workgroup_idx / uniforms.num_N_tile); let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; - // Handle each workgroup threads as a block of [sub_tile_size][tile_size_k_vec] + // Handle each workgroup threads as a block of [sub_tile_count][tile_size_k_vec] let local_col = local_idx % tile_size_k_vec; let local_row = local_idx / tile_size_k_vec; for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec) { // Load Phase: Populate shared memory for the workgroup. - if (local_idx < 32) + if (local_idx < double_tile_size_k_vec) { loadSHMA(a_global, kidx_v * 2, local_idx); } @@ -338,15 +395,25 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co var own_b1 = vec4(0); let k_offset = kidx_v + local_col; // calculate intermediate results into inter_results. - for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_size) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { let b_global = b_global_base + row_offset + local_row; if (b_global < uniforms.N && k_offset < uniforms.K32) { let b_offset = b_global * uniforms.K32 + k_offset; + )MAIN_FN"; + if (nbits_ == 4) { + shader.MainFunctionBody() << R"MAIN_FN( let b_value = input_b[b_offset]; own_b = DequantizedFrom4BitsTo8Bits(b_value.xy); own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw); - + )MAIN_FN"; + } else { + shader.MainFunctionBody() << R"MAIN_FN( + own_b = AlignWithZeroPoint(input_b[b_offset * 2]); + own_b1 = AlignWithZeroPoint(input_b[b_offset * 2 + 1]); + )MAIN_FN"; + } + shader.MainFunctionBody() << R"MAIN_FN( // k_offset - covers 32 values of k in input_b let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + k_offset * 32 / uniforms.block_size]; inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); @@ -378,6 +445,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor uint32_t K, uint32_t block_size, uint32_t min_M_for_tile_optimization, + uint32_t nbits, onnxruntime::webgpu::ComputeContext& context, Tensor* y) { constexpr uint32_t kVec4Components = 4; @@ -399,7 +467,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor if (M < min_M_for_tile_optimization) { constexpr uint32_t kTileSize = 32; - DP4AMatMulNBitsSmallMProgram mul_program{kTileSize}; + DP4AMatMulNBitsSmallMProgram mul_program{kTileSize, nbits}; uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; mul_program.SetWorkgroupSize(128); mul_program.SetDispatchGroupSize(M * num_N_tile); @@ -408,7 +476,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1}); + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1}) + .CacheHint(nbits); return context.RunProgram(mul_program); } @@ -416,12 +485,12 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor TensorShape reshaped_y_shape{1, M, N / kVec4Components}; uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize; uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; - DP4AMatMulNBitsProgram mul_program{block_size}; + DP4AMatMulNBitsProgram mul_program{block_size, nbits}; mul_program.SetWorkgroupSize(256); mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, - {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec2Components * kU32Components)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(nbits == 4 ? kVec2Components * kU32Components : kVec4Components * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({{static_cast(M)}, {static_cast(N)}, @@ -430,7 +499,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {static_cast(K / 16)}, {num_N_tile}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(kVec4Components)}) - .CacheHint("Block" + std::to_string(block_size)); + .CacheHint("Block" + std::to_string(block_size), nbits); return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 67e2e7d66e83a..647e200ce93b7 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -20,7 +20,7 @@ class DP4AMatMulQuantizeProgram final : public Program { public: - DP4AMatMulNBitsProgram(uint32_t block_size) : Program{"DP4AMatMulNBits"}, block_size_(block_size) {} + DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits) : Program{"DP4AMatMulNBits"}, block_size_(block_size), nbits_(nbits) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -32,11 +32,12 @@ class DP4AMatMulNBitsProgram final : public Program { private: uint32_t block_size_; + uint32_t nbits_; }; class DP4AMatMulNBitsSmallMProgram final : public Program { public: - DP4AMatMulNBitsSmallMProgram(uint32_t tile_size) : Program{"DP4AMatMulNBitsSmallMProgram"}, tile_size_(tile_size) {} + DP4AMatMulNBitsSmallMProgram(uint32_t tile_size, uint32_t nbits) : Program{"DP4AMatMulNBitsSmallMProgram"}, tile_size_(tile_size), nbits_(nbits) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -49,6 +50,7 @@ class DP4AMatMulNBitsSmallMProgram final : public Program(helper.N()); const uint32_t K = onnxruntime::narrow(helper.K()); const uint32_t block_size = onnxruntime::narrow(block_size_); - constexpr uint32_t nbits = 4; + const uint32_t nbits = onnxruntime::narrow(bits_); const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size; const uint32_t blob_size = (block_size / 8) * nbits; @@ -720,16 +720,19 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context // macOS - Experimental dawn support for subgroup matrix matmul on Metal. if (M >= kMinMForTileOptimization && CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, has_zero_points)) { - return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y); + return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, nbits, context, y); } // On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. - if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || + if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || nbits == 8 || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { - return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y); + return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, nbits, context, y); } + // TODO: Remvoe it once the 8bits is supported for the non-dp4 path. + ORT_ENFORCE(nbits == 4, "Only 4 bits are supported for the non-dp4 path for webgpu matmulnbits"); + // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. // TODO: loosen restrictions on vendor. diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index d8f4d74da4876..d5e4bc68fc33a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -55,10 +55,10 @@ class MatMulNBits final : public WebGpuKernel { K_ = info.GetAttr("K"); N_ = info.GetAttr("N"); block_size_ = info.GetAttr("block_size"); - int64_t bits = info.GetAttr("bits"); + bits_ = info.GetAttr("bits"); accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); - ORT_ENFORCE(bits == 4, - "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(bits_ == 4 || bits_ == 8, + "Only 4b/8b quantization is supported for MatMulNBits op, additional bits support is planned."); } Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; @@ -68,6 +68,7 @@ class MatMulNBits final : public WebGpuKernel { int64_t N_; int64_t block_size_; int64_t accuracy_level_; + int64_t bits_; }; } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index b1dce049214eb..09650be9358d0 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -41,7 +41,9 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader tile_A[row * tile_k + col + col_offset] = compute_precision(input_a[a_global*uniforms.K + k_idx + col + col_offset]); } } - + )ADDNL_FN"; + if (nbits_ == 4) { + shader.AdditionalImplementation() << R"ADDNL_FN( fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { let b_global = tile_base + row; if (b_global >= uniforms.N) { @@ -69,7 +71,40 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader tile_B[tile_b_base + 7] = b_value_upper[3]; } } - + )ADDNL_FN"; + } else { + ORT_ENFORCE(nbits_ == 8, "Only 4/8 bits are supported for webgpu matmulnbits"); + shader.AdditionalImplementation() << R"ADDNL_FN( + fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { + let b_global = tile_base + row; + if (b_global >= uniforms.N) { + return; + } + // Each call loads 16 columns, starting at col. + var col = c_idx * 16; + // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread. + // Stored in column major fashion. + let b_idx = u32((b_global*uniforms.K + k_idx + col)/8); + let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]); + for (var step:u32 = 0; step < 2; step++) + { + var b_value = input_b[b_idx+step]; + var b_value0 = (vec4(unpack4xU8(b_value[0])) - vec4(128)) * scale; + var b_value1 = (vec4(unpack4xU8(b_value[1])) - vec4(128)) * scale; + let tile_b_base = row * tile_k + col + step * 8; + tile_B[tile_b_base] = b_value0[0]; + tile_B[tile_b_base + 1] = b_value0[1]; + tile_B[tile_b_base + 2] = b_value0[2]; + tile_B[tile_b_base + 3] = b_value0[3]; + tile_B[tile_b_base + 4] = b_value1[0]; + tile_B[tile_b_base + 5] = b_value1[1]; + tile_B[tile_b_base + 6] = b_value1[2]; + tile_B[tile_b_base + 7] = b_value1[3]; + } + } + )ADDNL_FN"; + } + shader.AdditionalImplementation() << R"ADDNL_FN( fn storeOutput(offset:u32, row: u32, col:u32, src_slot:u32, row_limit:i32) { if (row_limit > 0 && row < u32(row_limit)) { @@ -174,24 +209,26 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te uint32_t M, uint32_t N, uint32_t K, + uint32_t nbits, onnxruntime::webgpu::ComputeContext& context, Tensor* y) { constexpr uint32_t kTileSizeA = 32; constexpr uint32_t kTileSizeB = 64; constexpr uint32_t kU32Components = 4; TensorShape y_shape{1, M, N}; - SubgroupMatrixMatMulNBitsProgram mul_program; + SubgroupMatrixMatMulNBitsProgram mul_program{nbits}; mul_program.SetWorkgroupSize(128); mul_program.SetDispatchGroupSize( (N + kTileSizeB - 1) / kTileSizeB, (M + kTileSizeA - 1) / kTileSizeA, 1); mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}, - {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kU32Components)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(nbits == 4 ? kU32Components : 2 * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({{static_cast(M)}, {static_cast(N)}, {static_cast(K)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1}); + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1}) + .CacheHint(nbits); return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h index 57a0b1066326a..a233e6a54f4fc 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h @@ -17,18 +17,22 @@ using namespace onnxruntime::webgpu; class SubgroupMatrixMatMulNBitsProgram final : public Program { public: - SubgroupMatrixMatMulNBitsProgram() : Program{"SubgroupMatrixMatMulNBits"} {} + SubgroupMatrixMatMulNBitsProgram(uint32_t nbits) : Program{"SubgroupMatrixMatMulNBits"}, nbits_(nbits) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, {"N", ProgramUniformVariableDataType::Uint32}, {"K", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t nbits_; }; Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, uint32_t M, uint32_t N, uint32_t K, + uint32_t nbits, onnxruntime::webgpu::ComputeContext& context, Tensor* y); diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 80e0fe1ae3484..97766028cfe12 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -42,8 +42,6 @@ #include -#define HAS_WINDOWS_DESKTOP WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) - #ifndef PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE #define PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE 43 #endif @@ -207,21 +205,22 @@ void CPUIDInfo::ArmWindowsInit() { // Get the ARM vendor string from the registry vendor_ = GetArmWindowsVendor(); -// ARM32 certainly doesn't have fp16, so we will skip the logic to avoid using RegGetValueA Windows API -#if !defined(_M_ARM) -#pragma region Application Family or OneCore Family -#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) - // Read MIDR from windows registry + // Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry + // There should be one per CPU + std::vector midr_values{}, id_aa64isar1_el1_values{}; + // TODO!! Don't support multiple processor group yet!! constexpr int MAX_CORES = 64; constexpr int MAX_VALUE_NAME = 4096; - CHAR midrKey[MAX_VALUE_NAME] = ""; // buffer for processor registry name - uint32_t lastUarch = cpuinfo_uarch_unknown; - for (int i = 0; i < MAX_CORES - 1; i++) { - snprintf(midrKey, MAX_VALUE_NAME, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\%d", i); - uint64_t midrVal; - unsigned long midrSize = sizeof(uint64_t); + CHAR processor_subkey[MAX_VALUE_NAME] = ""; // buffer for processor registry name + + for (size_t i = 0; i < MAX_CORES - 1; i++) { + snprintf(processor_subkey, MAX_VALUE_NAME, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\%d", + static_cast(i)); + + uint64_t midr_value; + unsigned long data_size = sizeof(midr_value); /* * ARM lists for each coprocessor register 5 fields: op0/op1/CRn/CRm/op2. @@ -236,48 +235,65 @@ void CPUIDInfo::ArmWindowsInit() { * * For the CP value of MIDR, op0 = 3 and the others are all = 0, so we come up with 0x4000, */ - auto retCode = ::RegGetValueA(HKEY_LOCAL_MACHINE, midrKey, "CP 4000", RRF_RT_REG_QWORD, nullptr, &midrVal, &midrSize); - if (retCode != ERROR_SUCCESS) { + if (::RegGetValueA(HKEY_LOCAL_MACHINE, processor_subkey, "CP 4000", RRF_RT_REG_QWORD, + nullptr, &midr_value, &data_size) != ERROR_SUCCESS) { break; } - uint32_t uarch = cpuinfo_uarch_unknown; - decodeMIDR((uint32_t)midrVal, &uarch); - core_uarchs_.push_back(uarch); - if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || - uarch == cpuinfo_uarch_cortex_a55) { - is_armv8_narrow_ld_.push_back(true); - } else { - is_armv8_narrow_ld_.push_back(false); + + uint64_t id_aa64isar1_el1_value; + data_size = sizeof(id_aa64isar1_el1_value); + + // CP 4031 corresponds to ID_AA64ISAR1_EL1 register + if (::RegGetValueA(HKEY_LOCAL_MACHINE, processor_subkey, "CP 4031", RRF_RT_REG_QWORD, + nullptr, &id_aa64isar1_el1_value, &data_size) != ERROR_SUCCESS) { + break; } - if (i == 0) { - lastUarch = uarch; - } else if (lastUarch != uarch) { - is_hybrid_ = true; - lastUarch = uarch; + midr_values.push_back(midr_value); + id_aa64isar1_el1_values.push_back(id_aa64isar1_el1_value); + } + + // process midr_values + { + uint32_t lastUarch = cpuinfo_uarch_unknown; + for (size_t i = 0; i < midr_values.size(); ++i) { + uint32_t uarch = cpuinfo_uarch_unknown; + decodeMIDR(static_cast(midr_values[i]), &uarch); + core_uarchs_.push_back(uarch); + if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || + uarch == cpuinfo_uarch_cortex_a55) { + is_armv8_narrow_ld_.push_back(true); + } else { + is_armv8_narrow_ld_.push_back(false); + } + + if (i == 0) { + lastUarch = uarch; + } else if (lastUarch != uarch) { + is_hybrid_ = true; + lastUarch = uarch; + } } } -#endif // WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) + + has_arm_neon_i8mm_ = std::all_of( + id_aa64isar1_el1_values.begin(), id_aa64isar1_el1_values.end(), + [](uint64_t id_aa64isar1_el1_value) { + // I8MM, bits [55:52] + return ((id_aa64isar1_el1_value >> 52) & 0xF) != 0; + }); has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#else // ^ !defined(_M_ARM) / v defined(_M_ARM) - has_arm_neon_dot_ = false; -#endif // defined(_M_ARM) #if defined(CPUINFO_SUPPORTED) if (pytorch_cpuinfo_init_) { has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); - has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); - has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + // cpuinfo_has_arm_i8mm() doesn't work on Windows yet. See https://github.com/pytorch/cpuinfo/issues/279. + // has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && has_arm_neon_i8mm_; has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); - } else -#endif // defined(CPUINFO_SUPPORTED) - { - has_fp16_ = false; - has_arm_neon_i8mm_ = false; - has_arm_sve_i8mm_ = false; - has_arm_neon_bf16_ = false; } +#endif // defined(CPUINFO_SUPPORTED) } std::string CPUIDInfo::GetArmWindowsVendor() { diff --git a/onnxruntime/core/dll/onnxruntime.rc b/onnxruntime/core/dll/onnxruntime.rc index 4b08dfdb7e0f6..2e7eb7842f935 100644 --- a/onnxruntime/core/dll/onnxruntime.rc +++ b/onnxruntime/core/dll/onnxruntime.rc @@ -2,7 +2,7 @@ // Licensed under the MIT License. // This file REQUIRES the following external definitions: -// VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE, and VER_STRING +// FILE_NAME, VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE, and VER_STRING #include @@ -43,4 +43,4 @@ BEGIN BEGIN VALUE "Translation", 0x409, 1252 END -END \ No newline at end of file +END diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index ecd3960107926..62895c0137a78 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include "core/common/exceptions.h" #include "core/common/inlined_containers.h" #include "core/common/safeint.h" @@ -139,7 +140,7 @@ class PlannerImpl { const InlinedHashMap& outer_scope_node_arg_to_location_map, const OrtValueNameIdxMap& ort_value_name_idx_map, const ISequentialPlannerContext& context, SequentialExecutionPlan& plan, - const logging::Logger& logger) + [[maybe_unused]] const logging::Logger& logger) : context_(&context), plan_(plan), parent_node_(parent_node), @@ -149,8 +150,13 @@ class PlannerImpl { kernel_create_info_map_(kernel_create_info_map), subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps), outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map), - ort_value_name_idx_map_(ort_value_name_idx_map), + ort_value_name_idx_map_(ort_value_name_idx_map) +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + , logger_(logger) { +#else + { +#endif } Status CreatePlan( @@ -186,10 +192,9 @@ class PlannerImpl { InlinedHashMap value_node_map_; // logger_ is not currently used in a minimal build -#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) - [[maybe_unused]] -#endif +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const logging::Logger& logger_; +#endif // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: struct OrtValueInfo { @@ -725,6 +730,25 @@ class PlannerImpl { ProcessDef(index, graph_viewer_.GetNodeArg(pair.first)); } + // If the suggested_device is also CPU and default mem type, then + // we check which one has higher alignment and use that one if it is so. + // If the suggested device is CPU, but not the default mem type, then + // it is a CPU accessible memory device allocator. They typically have a page aligment + // so that would satisfy the alignment requirement of any other CPU consumers. + // If one device is not on CPU, we default on the one that is CPU. + auto determine_device = [](const OrtDevice& output_device, const OrtDevice& suggested_device) -> OrtDevice { + if (output_device.Type() == OrtDevice::CPU && suggested_device.Type() == OrtDevice::CPU) { + if (output_device.MemType() == OrtDevice::MemType::DEFAULT && + suggested_device.MemType() == OrtDevice::MemType::DEFAULT) { + return (output_device.GetAlignment() >= suggested_device.GetAlignment()) ? output_device : suggested_device; + } else { + return (output_device.MemType() != OrtDevice::MemType::DEFAULT) ? output_device : suggested_device; + } + } else { + return (output_device.Type() == OrtDevice::CPU) ? output_device : suggested_device; + } + }; + InlinedHashSet set_node_arg_has_explicit_consumer; InlinedHashMap map_implicitly_consumed_node_arg_to_ep; @@ -756,6 +780,7 @@ class PlannerImpl { // Add location information if applicable for the provided input def auto process_input = [&graph_inputs, &exec_provider, &p_kernel_def, &is_implicit_input, &set_node_arg_has_explicit_consumer, + &determine_device, &map_implicitly_consumed_node_arg_to_ep, &set_implicitly_consumed_node_arg_has_heterogenous_ep_consumers, this](const NodeArg& input, size_t arg_idx) { @@ -856,9 +881,12 @@ class PlannerImpl { // we have seen plan_.SetLocation(static_cast(index), exec_provider->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault)); } else { - // Default the location to CPU - plan_.SetLocation(static_cast(index), - execution_providers_.Get(CPU)->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault)); + // We want to minimize the amount of copies, so we want at least one + // device to match or match both if they are CPU based. + OrtDevice result = determine_device( + already_seen_ep_for_node_arg->second->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault), + exec_provider->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault)); + plan_.SetLocation(static_cast(index), result); set_implicitly_consumed_node_arg_has_heterogenous_ep_consumers.insert(index); } } @@ -881,7 +909,37 @@ class PlannerImpl { if (!node_output->Exists()) continue; OrtValueIndex index = Index(node_output->Name()); ProcessDef(index, node_output); - plan_.SetLocation(static_cast(index), exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i))); + OrtDevice output_device = exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i)); +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // Downstream nodes of certain providers may require a CPU accessible location override + // to make sure the EP does not incur an unnecessary copy. + // We only do it for CPU based EPs. We are not likely to encounter + // non CPU devices here since they are already taken care of by using MemCpy nodes earlier. + // However, we still ignore them. + if (output_device.Type() == OrtDevice::CPU && + output_device.MemType() == OrtDevice::MemType::DEFAULT) { + const auto& output_name = node_output->Name(); + const auto consumers = graph_viewer_.GetConsumerNodes(output_name); + for (const auto* consumer : consumers) { + if (consumer != nullptr) { + const auto& ep_type = consumer->GetExecutionProviderType(); + auto suggested_device = execution_providers_.Get(ep_type)->GetOrtDeviceByMemType( + OrtMemType::OrtMemTypeCPUInput); + if (suggested_device.Type() == OrtDevice::CPU && + suggested_device.MemType() == OrtDevice::MemType::DEFAULT) { + output_device = determine_device(output_device, suggested_device); + } else if (suggested_device.Type() == OrtDevice::CPU) { + // Edge case: there are more than one downstream nodes that suggest their own CPU accessible + // memory. In that case, we can not win them all, but the chosen device would still make it run + // and reduce a number of copies for some. + output_device = suggested_device; + break; + } + } + } + } +#endif + plan_.SetLocation(static_cast(index), output_device); } } } diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 02dbb3e518783..ba469a89abd00 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -41,8 +41,7 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz } #ifdef USE_MIMALLOC -void* AllocatorDefaultAlloc(size_t size) { - const size_t alignment = MlasGetPreferredBufferAlignment(); +void* AllocatorDefaultAllocAligned(size_t size, size_t alignment) { if (size <= 0) return nullptr; size += MLAS_SYMM_QGEMM_BUF_OVERRUN; void* p; @@ -71,10 +70,18 @@ void AllocatorDefaultFree(void* p) { #endif } +void AllocatorDefaultFreeAligned(void* p, size_t alignment) { +#if defined(_MSC_VER) + mi_free_aligned(p, alignment); #else -void* AllocatorDefaultAlloc(size_t size) { - const size_t alignment = MlasGetPreferredBufferAlignment(); - if (size <= 0) return nullptr; + mi_free(p); +#endif +} + +#else + +void* AllocatorDefaultAllocAligned(size_t size, size_t alignment) { + if (size == 0) return nullptr; size += MLAS_SYMM_QGEMM_BUF_OVERRUN; void* p; #if _MSC_VER @@ -101,14 +108,25 @@ void AllocatorDefaultFree(void* p) { #endif } +void AllocatorDefaultFreeAligned(void* p, size_t /* alignment */) { + AllocatorDefaultFree(p); +} + #endif // USE_MIMALLOC +void* AllocatorDefaultAlloc(size_t size) { + const size_t alignment = MlasGetPreferredBufferAlignment(); + return AllocatorDefaultAllocAligned(size, alignment); +} + void* CPUAllocator::Alloc(size_t size) { - return AllocatorDefaultAlloc(size); + const auto alignment = std::max(Info().device.GetAlignment(), MlasGetPreferredBufferAlignment()); + return AllocatorDefaultAllocAligned(size, alignment); } void CPUAllocator::Free(void* p) { - AllocatorDefaultFree(p); + const auto alignment = std::max(Info().device.GetAlignment(), MlasGetPreferredBufferAlignment()); + AllocatorDefaultFreeAligned(p, alignment); } void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn) { @@ -168,6 +186,11 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA onnxruntime::QNN_HTP_SHARED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, static_cast(id1)), id1, mem_type1); + } else if (strcmp(name1, onnxruntime::CPU_ALIGNED_4K) == 0) { + *out = new OrtMemoryInfo( + onnxruntime::CPU_ALIGNED_4K, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, static_cast(id1), onnxruntime::kAlloc4KAlignment), + id1, mem_type1); } else { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported."); } diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index c5046353ba528..c103aa206f3db 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -529,8 +529,11 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va return Status(ONNXRUNTIME, FAIL, "Trying to allocate memory for unused optional inputs/outputs"); } + // This alignment is used to properly space out individual chunks in mempatterns memory buffer. + const auto alignment = std::max(location.GetAlignment(), kAllocAlignment); + size_t size = 0; - ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(element_type, shape, kAllocAlignment, size)); + ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(element_type, shape, alignment, size)); // Lazily get the allocator only if needed. AllocatorPtr alloc = nullptr; diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc index 423307b4c8fca..dabb81f2afe71 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc @@ -52,242 +52,284 @@ Status AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(KernelTypeStrRe // clang-format off constexpr uint8_t kLayoutTransformationRequiredOpsKernelTypeStrResolverBytes[] = { - 0x10, 0x00, 0x00, 0x00, 0x6b, 0x74, 0x73, 0x72, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x88, 0x0d, 0x00, 0x00, - 0xec, 0x06, 0x00, 0x00, 0x68, 0x06, 0x00, 0x00, 0x1c, 0x08, 0x00, 0x00, 0xc8, 0x02, 0x00, 0x00, - 0x2c, 0x03, 0x00, 0x00, 0x80, 0x01, 0x00, 0x00, 0xc0, 0x09, 0x00, 0x00, 0xdc, 0x03, 0x00, 0x00, - 0x6c, 0x09, 0x00, 0x00, 0x64, 0x02, 0x00, 0x00, 0xbc, 0x0c, 0x00, 0x00, 0x04, 0x0d, 0x00, 0x00, - 0xd4, 0x00, 0x00, 0x00, 0x10, 0x04, 0x00, 0x00, 0x04, 0x05, 0x00, 0x00, 0x68, 0x08, 0x00, 0x00, - 0x70, 0x03, 0x00, 0x00, 0xf0, 0x0d, 0x00, 0x00, 0x8c, 0x04, 0x00, 0x00, 0x6c, 0x05, 0x00, 0x00, - 0x94, 0x0a, 0x00, 0x00, 0x44, 0x0c, 0x00, 0x00, 0x28, 0x07, 0x00, 0x00, 0xc4, 0x05, 0x00, 0x00, - 0xc0, 0x09, 0x00, 0x00, 0x08, 0x0a, 0x00, 0x00, 0xb8, 0x08, 0x00, 0x00, 0x90, 0x01, 0x00, 0x00, - 0x5c, 0x07, 0x00, 0x00, 0xbc, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x24, 0xf2, 0xff, 0xff, - 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, - 0x28, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, - 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, - 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x60, 0xf2, 0xff, 0xff, 0x64, 0x0b, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x4e, 0xf2, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb8, 0xf2, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0x88, 0xf2, 0xff, 0xff, 0x10, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd8, 0xf2, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x70, 0xf2, 0xff, 0xff, 0xac, 0xf2, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, - 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, - 0x65, 0x61, 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0xe0, 0xf2, 0xff, 0xff, 0xb8, 0x0a, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbc, 0xf2, 0xff, 0xff, - 0xf8, 0xf2, 0xff, 0xff, 0xcc, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe6, 0xf2, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x50, 0xf3, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x20, 0xf3, 0xff, 0xff, 0x50, 0x0a, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6c, 0xf3, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x3c, 0xf3, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, - 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x64, 0xf3, 0xff, 0xff, - 0xd4, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xb0, 0xf3, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x80, 0xf3, 0xff, 0xff, 0x90, 0x0c, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x6e, 0xf3, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x68, 0xf3, 0xff, 0xff, 0xa4, 0xf3, 0xff, 0xff, - 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, - 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, - 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, - 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0xe4, 0xf3, 0xff, 0xff, - 0xe0, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xd2, 0xf3, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x3c, 0xf4, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x0c, 0xf4, 0xff, 0xff, 0x8c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x5c, 0xf4, 0xff, 0xff, - 0x02, 0x00, 0x00, 0x00, 0xf4, 0xf3, 0xff, 0xff, 0x30, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, 0x00, 0x00, 0x00, 0x00, - 0x58, 0xf4, 0xff, 0xff, 0xb0, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x46, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x40, 0xf4, 0xff, 0xff, 0x7c, 0xf4, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, - 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0xa4, 0xf4, 0xff, 0xff, - 0x94, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xf0, 0xf4, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xc0, 0xf4, 0xff, 0xff, 0x50, 0x0b, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xae, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xa8, 0xf4, 0xff, 0xff, 0xe4, 0xf4, 0xff, 0xff, - 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x38, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, - 0x31, 0x31, 0x00, 0x00, 0x0c, 0xf5, 0xff, 0xff, 0x04, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfa, 0xf4, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xf4, 0xf4, 0xff, 0xff, 0x30, 0xf5, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x88, 0xf5, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x58, 0xf5, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, - 0x3a, 0x31, 0x00, 0x00, 0x7c, 0xf5, 0xff, 0xff, 0x94, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6a, 0xf5, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x64, 0xf5, 0xff, 0xff, 0xa0, 0xf5, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, - 0xc8, 0xf5, 0xff, 0xff, 0x48, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb6, 0xf5, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xb0, 0xf5, 0xff, 0xff, 0xec, 0xf5, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, - 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, - 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x20, 0xf6, 0xff, 0xff, 0xa4, 0x07, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x0e, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x78, 0xf6, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0x48, 0xf6, 0xff, 0xff, 0x50, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x24, 0xf6, 0xff, 0xff, 0x60, 0xf6, 0xff, 0xff, 0x10, 0x07, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xac, 0xf6, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x7c, 0xf6, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, - 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0xa4, 0xf6, 0xff, 0xff, - 0xc8, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xf0, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xc0, 0xf6, 0xff, 0xff, 0x50, 0x09, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xae, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xa8, 0xf6, 0xff, 0xff, 0xe4, 0xf6, 0xff, 0xff, - 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, - 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, - 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x14, 0xf7, 0xff, 0xff, - 0xb0, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x6c, 0xf7, 0xff, 0xff, - 0x02, 0x00, 0x00, 0x00, 0x3c, 0xf7, 0xff, 0xff, 0x5c, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8c, 0xf7, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x24, 0xf7, 0xff, 0xff, 0x60, 0xf7, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, - 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x31, 0x00, - 0x88, 0xf7, 0xff, 0xff, 0x88, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x76, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x70, 0xf7, 0xff, 0xff, 0xac, 0xf7, 0xff, 0xff, 0xc0, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xf7, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0xc8, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, - 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xf7, 0xff, 0xff, 0x20, 0x08, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xde, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xd8, 0xf7, 0xff, 0xff, 0x14, 0xf8, 0xff, 0xff, - 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, - 0x44, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, - 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, - 0x48, 0xf8, 0xff, 0xff, 0x50, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x98, 0xf8, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0x30, 0xf8, 0xff, 0xff, 0x6c, 0xf8, 0xff, 0xff, 0x58, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x5a, 0xf8, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xc4, 0xf8, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x94, 0xf8, 0xff, 0xff, - 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, - 0x64, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, - 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, - 0x00, 0x00, 0x00, 0x00, 0xcc, 0xf8, 0xff, 0xff, 0xc8, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb6, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xe8, 0xf8, 0xff, 0xff, 0x28, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf9, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0xd0, 0xf8, 0xff, 0xff, 0x0c, 0xf9, 0xff, 0xff, 0x60, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x58, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x28, 0xf9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, - 0x73, 0x65, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0x50, 0xf9, 0xff, 0xff, 0xc0, 0x06, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x3e, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xf9, 0xff, 0xff, 0x74, 0xf9, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, - 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, - 0x74, 0x3a, 0x4e, 0x68, 0x77, 0x63, 0x4d, 0x61, 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, - 0xa8, 0xf9, 0xff, 0xff, 0x68, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x96, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x90, 0xf9, 0xff, 0xff, 0xcc, 0xf9, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, - 0x72, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfa, 0xff, 0xff, 0x98, 0x03, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x50, 0xfa, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xe8, 0xf9, 0xff, 0xff, 0x24, 0xfa, 0xff, 0xff, - 0xa0, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x12, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x7c, 0xfa, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x4c, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, - 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, - 0x32, 0x31, 0x00, 0x00, 0x7c, 0xfa, 0xff, 0xff, 0x48, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6a, 0xfa, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xd4, 0xfa, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xa4, 0xfa, 0xff, 0xff, - 0xf4, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xf4, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x8c, 0xfa, 0xff, 0xff, - 0xc8, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x1c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, - 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0xf4, 0xfa, 0xff, 0xff, - 0x1c, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xe2, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xdc, 0xfa, 0xff, 0xff, - 0x18, 0xfb, 0xff, 0xff, 0x54, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x64, 0xfb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x34, 0xfb, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, - 0x00, 0x00, 0x00, 0x00, 0x5c, 0xfb, 0xff, 0xff, 0xac, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x4a, 0xfb, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x44, 0xfb, 0xff, 0xff, 0x80, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, 0xa4, 0xfb, 0xff, 0xff, - 0x6c, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x92, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x8c, 0xfb, 0xff, 0xff, - 0xc8, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, - 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x00, 0x00, 0xf0, 0xfb, 0xff, 0xff, 0x20, 0x04, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xde, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xd8, 0xfb, 0xff, 0xff, 0x14, 0xfc, 0xff, 0xff, - 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, - 0x3c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, - 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x40, 0xfc, 0xff, 0xff, 0xd0, 0x03, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x2e, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x28, 0xfc, 0xff, 0xff, 0x64, 0xfc, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x61, 0x78, 0x65, 0x73, - 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbc, 0xfc, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x8c, 0xfc, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, - 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xfc, 0xff, 0xff, - 0x5c, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xa2, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xfc, 0xff, 0xff, - 0xd8, 0xfc, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, - 0xa8, 0x00, 0x00, 0x00, 0xd0, 0x00, 0x00, 0x00, 0xfc, 0x00, 0x00, 0x00, 0x28, 0x01, 0x00, 0x00, - 0x2c, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, - 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x4c, - 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x31, 0x00, 0x24, 0xfd, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x7c, 0xfd, 0xff, 0xff, - 0x04, 0x00, 0x00, 0x00, 0x4c, 0xfd, 0xff, 0xff, 0x20, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x98, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x68, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, - 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xc0, 0xfd, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, 0x90, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x31, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe8, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0x80, 0xfd, 0xff, 0xff, 0xbc, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x14, 0xfe, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, 0x1c, 0xfe, 0xff, 0xff, - 0x03, 0x00, 0x00, 0x00, 0xec, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xe2, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x4c, 0xfe, 0xff, 0xff, - 0x07, 0x00, 0x00, 0x00, 0x1c, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x70, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x40, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, - 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, - 0x68, 0xfe, 0xff, 0xff, 0xa8, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x56, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x50, 0xfe, 0xff, 0xff, 0x8c, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, - 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xfe, 0xff, 0xff, - 0x54, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xa2, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xfe, 0xff, 0xff, - 0xd8, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, - 0x79, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf6, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xf0, 0xfe, 0xff, 0xff, 0x2c, 0xff, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, - 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, - 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, 0x64, 0xff, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x78, 0x5f, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xbc, 0xff, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x8c, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x7e, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb0, 0xff, 0xff, 0xff, 0x60, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0xa0, 0xff, 0xff, 0xff, 0xdc, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, - 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x6b, 0x74, 0x73, 0x72, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0xac, 0x09, 0x00, 0x00, + 0xfc, 0x0f, 0x00, 0x00, 0xb8, 0x08, 0x00, 0x00, 0xd8, 0x03, 0x00, 0x00, 0x5c, 0x0c, 0x00, 0x00, + 0x5c, 0x05, 0x00, 0x00, 0x28, 0x0a, 0x00, 0x00, 0x3c, 0x02, 0x00, 0x00, 0xec, 0x02, 0x00, 0x00, + 0x04, 0x0f, 0x00, 0x00, 0xa4, 0x0b, 0x00, 0x00, 0x94, 0x01, 0x00, 0x00, 0xdc, 0x01, 0x00, 0x00, + 0x64, 0x00, 0x00, 0x00, 0x88, 0x02, 0x00, 0x00, 0x18, 0x03, 0x00, 0x00, 0xf0, 0x00, 0x00, 0x00, + 0x00, 0x08, 0x00, 0x00, 0x2c, 0x0f, 0x00, 0x00, 0x68, 0x0a, 0x00, 0x00, 0xd8, 0x04, 0x00, 0x00, + 0x8c, 0x0e, 0x00, 0x00, 0xcc, 0x05, 0x00, 0x00, 0x80, 0x07, 0x00, 0x00, 0x18, 0x0e, 0x00, 0x00, + 0xa4, 0x0c, 0x00, 0x00, 0x7c, 0x00, 0x00, 0x00, 0xc0, 0x0d, 0x00, 0x00, 0xb0, 0x0b, 0x00, 0x00, + 0x64, 0x05, 0x00, 0x00, 0x68, 0x0d, 0x00, 0x00, 0xc4, 0x08, 0x00, 0x00, 0x3c, 0x04, 0x00, 0x00, + 0x24, 0x10, 0x00, 0x00, 0xcc, 0x0c, 0x00, 0x00, 0xd8, 0x03, 0x00, 0x00, 0xfc, 0x05, 0x00, 0x00, + 0xb0, 0x0a, 0x00, 0x00, 0x8c, 0xef, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0x00, 0xb4, 0xef, 0xff, 0xff, + 0x3c, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xe2, 0xef, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xdc, 0xef, 0xff, 0xff, + 0xd8, 0xef, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, + 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xff, 0xff, 0xd0, 0x0f, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x2e, 0xf0, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x28, 0xf0, 0xff, 0xff, 0x24, 0xf0, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, + 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, + 0x58, 0xf0, 0xff, 0xff, 0x58, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x44, 0xf0, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x74, 0xf0, 0xff, 0xff, + 0x2c, 0x0e, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa2, 0xf0, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x6c, 0xf0, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0x9c, 0xf0, 0xff, 0xff, 0x34, 0x0e, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb8, 0xf0, 0xff, 0xff, 0xb4, 0xf0, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, + 0x00, 0x00, 0x00, 0x00, 0xdc, 0xf0, 0xff, 0xff, 0x14, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0a, 0xf1, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xf1, 0xff, 0xff, 0x00, 0xf1, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, + 0x28, 0xf1, 0xff, 0xff, 0xc8, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x56, 0xf1, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x50, 0xf1, 0xff, 0xff, 0x4c, 0xf1, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x74, 0xf1, 0xff, 0xff, + 0x5c, 0x0e, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa2, 0xf1, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xf1, 0xff, 0xff, + 0x98, 0xf1, 0xff, 0xff, 0xcc, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x84, 0xf1, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xb4, 0xf1, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x32, 0x33, + 0x00, 0x00, 0x00, 0x00, 0xdc, 0xf1, 0xff, 0xff, 0x14, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0a, 0xf2, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xf2, 0xff, 0xff, 0x00, 0xf2, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, 0x24, 0xf2, 0xff, 0xff, + 0xac, 0x0d, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x52, 0xf2, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x4c, 0xf2, 0xff, 0xff, + 0x48, 0xf2, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x68, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x31, 0x30, 0x00, 0x00, 0x7c, 0xf2, 0xff, 0xff, 0x34, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x68, 0xf2, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x98, 0xf2, 0xff, 0xff, 0x08, 0x0c, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc6, 0xf2, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x90, 0xf2, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xc0, 0xf2, 0xff, 0xff, 0x10, 0x0c, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xdc, 0xf2, 0xff, 0xff, + 0xd8, 0xf2, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, + 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x31, + 0x00, 0x00, 0x00, 0x00, 0x0c, 0xf3, 0xff, 0xff, 0x94, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x3a, 0xf3, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xf3, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x34, 0xf3, 0xff, 0xff, + 0x9c, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x24, 0xf3, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x5c, 0xf3, 0xff, 0xff, + 0x58, 0xf3, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, + 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x4e, 0x68, 0x77, 0x63, 0x4d, 0x61, 0x78, 0x50, 0x6f, 0x6f, + 0x6c, 0x3a, 0x31, 0x00, 0x8c, 0xf3, 0xff, 0xff, 0x44, 0x0c, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xba, 0xf3, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xb4, 0xf3, 0xff, 0xff, 0xb0, 0xf3, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, + 0x31, 0x00, 0x00, 0x00, 0xdc, 0xf3, 0xff, 0xff, 0xf4, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0a, 0xf4, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x04, 0xf4, 0xff, 0xff, 0x00, 0xf4, 0xff, 0xff, 0x10, 0x0c, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xec, 0xf3, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x1c, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, + 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x40, 0xf4, 0xff, 0xff, 0x90, 0x0b, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6e, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x68, 0xf4, 0xff, 0xff, 0x64, 0xf4, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, + 0x31, 0x00, 0x00, 0x00, 0x8c, 0xf4, 0xff, 0xff, 0xd8, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x78, 0xf4, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xa8, 0xf4, 0xff, 0xff, 0x28, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd6, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xd0, 0xf4, 0xff, 0xff, 0xcc, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, + 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0xf4, 0xf4, 0xff, 0xff, + 0xdc, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x22, 0xf5, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x1c, 0xf5, 0xff, 0xff, + 0x18, 0xf5, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, + 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x40, 0xf5, 0xff, 0xff, 0xd0, 0x0a, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2c, 0xf5, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x5c, 0xf5, 0xff, 0xff, 0x74, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8a, 0xf5, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x84, 0xf5, 0xff, 0xff, 0x80, 0xf5, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x8c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, + 0x34, 0x00, 0x00, 0x00, 0xcc, 0x00, 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0xe8, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, + 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, + 0x76, 0x3a, 0x31, 0x00, 0xcc, 0xf5, 0xff, 0xff, 0xd4, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfa, 0xf5, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xc4, 0xf5, 0xff, 0xff, 0x07, 0x00, 0x00, 0x00, 0xf4, 0xf5, 0xff, 0xff, + 0xac, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xe4, 0xf5, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, 0xec, 0xf5, 0xff, 0xff, + 0x03, 0x00, 0x00, 0x00, 0x1c, 0xf6, 0xff, 0xff, 0xb4, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0c, 0xf6, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0x44, 0xf6, 0xff, 0xff, 0x40, 0xf6, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf6, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0x68, 0xf6, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x54, 0x34, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x5c, 0xf6, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x8c, 0xf6, 0xff, 0xff, 0xa8, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x78, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xa8, 0xf6, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xa0, 0xf6, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, 0xd0, 0xf6, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x31, 0x00, + 0xf8, 0xf6, 0xff, 0xff, 0x18, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xe4, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x14, 0xf7, 0xff, 0xff, + 0xbc, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x42, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x3c, 0xf7, 0xff, 0xff, + 0x38, 0xf7, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, + 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, 0x00, 0x00, + 0x68, 0xf7, 0xff, 0xff, 0x38, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x96, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x60, 0xf7, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x90, 0xf7, 0xff, 0xff, 0x40, 0x07, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x80, 0xf7, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xb8, 0xf7, 0xff, 0xff, 0xb4, 0xf7, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, + 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, + 0xe8, 0xf7, 0xff, 0xff, 0xb8, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x16, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xe0, 0xf7, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x10, 0xf8, 0xff, 0xff, 0xc0, 0x06, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0xf8, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x38, 0xf8, 0xff, 0xff, 0x34, 0xf8, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, + 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x60, 0xf8, 0xff, 0xff, 0x70, 0x07, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x8e, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x88, 0xf8, 0xff, 0xff, 0x84, 0xf8, 0xff, 0xff, + 0x8c, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x70, 0xf8, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xa0, 0xf8, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, + 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, + 0xd8, 0xf8, 0xff, 0xff, 0x5c, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xc4, 0xf8, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xf4, 0xf8, 0xff, 0xff, + 0xdc, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xe4, 0xf8, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x1c, 0xf9, 0xff, 0xff, + 0x18, 0xf9, 0xff, 0xff, 0x44, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x42, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x34, 0xf9, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, + 0x31, 0x31, 0x00, 0x00, 0x5c, 0xf9, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x54, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x84, 0xf9, 0xff, 0xff, + 0x4c, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xb2, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xac, 0xf9, 0xff, 0xff, + 0xa8, 0xf9, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x68, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x32, 0x33, 0x00, 0x00, 0xdc, 0xf9, 0xff, 0xff, 0xc4, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0a, 0xfa, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xd4, 0xf9, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x04, 0xfa, 0xff, 0xff, + 0x9c, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xf0, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x20, 0xfa, 0xff, 0xff, 0xb0, 0x04, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x3c, 0xfa, 0xff, 0xff, + 0x38, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x54, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, + 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, + 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x74, 0xfa, 0xff, 0xff, + 0x2c, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa2, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x6c, 0xfa, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0x9c, 0xfa, 0xff, 0xff, 0x34, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8c, 0xfa, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0xc4, 0xfa, 0xff, 0xff, 0xc0, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, 0x00, 0x00, 0x00, 0x00, + 0xe8, 0xfa, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x56, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x1e, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x18, 0xfb, 0xff, 0xff, 0x14, 0xfb, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x32, + 0x33, 0x00, 0x00, 0x00, 0x3c, 0xfb, 0xff, 0xff, 0x94, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6a, 0xfb, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x64, 0xfb, 0xff, 0xff, 0x60, 0xfb, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, + 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x33, 0x00, 0x00, 0x00, 0x00, + 0x98, 0xfb, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x54, 0x33, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xca, 0xfb, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xbc, 0xfb, 0xff, 0xff, 0xe4, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa8, 0xfb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xd8, 0xfb, 0xff, 0xff, 0xf8, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc8, 0xfb, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x00, 0xfc, 0xff, 0xff, 0xfc, 0xfb, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, + 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0x24, 0xfc, 0xff, 0xff, + 0xac, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x52, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x4c, 0xfc, 0xff, 0xff, + 0x48, 0xfc, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x58, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, + 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, + 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, + 0x88, 0xfc, 0xff, 0xff, 0x18, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb6, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x80, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xb0, 0xfc, 0xff, 0xff, 0x20, 0x02, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xa0, 0xfc, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xd8, 0xfc, 0xff, 0xff, 0xd4, 0xfc, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, + 0x31, 0x00, 0x00, 0x00, 0xfc, 0xfc, 0xff, 0xff, 0xd4, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2a, 0xfd, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x24, 0xfd, 0xff, 0xff, 0x20, 0xfd, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, + 0x48, 0xfd, 0xff, 0xff, 0x88, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x76, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x70, 0xfd, 0xff, 0xff, 0x6c, 0xfd, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x33, 0x00, 0x94, 0xfd, 0xff, 0xff, + 0x7c, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x80, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xb0, 0xfd, 0xff, 0xff, 0x20, 0x02, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xde, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xd8, 0xfd, 0xff, 0xff, 0xd4, 0xfd, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, + 0xf8, 0xfd, 0xff, 0xff, 0xd8, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x26, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x20, 0xfe, 0xff, 0xff, 0x1c, 0xfe, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0x44, 0xfe, 0xff, 0xff, + 0x8c, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x72, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x6c, 0xfe, 0xff, 0xff, + 0x68, 0xfe, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, + 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x31, 0x00, 0x00, + 0x98, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x54, 0x32, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xce, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x98, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0xc8, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x54, 0x31, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xc0, 0xfe, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xf8, 0xfe, 0xff, 0xff, 0xf4, 0xfe, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, + 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, + 0x00, 0x00, 0x00, 0x00, 0x2c, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x24, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x54, 0xff, 0xff, 0xff, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x86, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x78, 0xff, 0xff, 0xff, 0x58, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x68, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0xa0, 0xff, 0xff, 0xff, 0x9c, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x33, 0x00, 0x00, 0x00, + 0xc8, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x61, 0x78, 0x65, 0x73, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, }; // clang-format on diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 334ecb3887d14..8b41460ccce21 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -41,6 +41,7 @@ #include "core/graph/function_impl.h" #include "core/graph/schema_registry.h" #include "onnx/checker.h" +#include "onnx/defs/parser.h" using namespace ONNX_NAMESPACE::checker; #endif diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h index 1e02e1264e09a..31f5a32a5af2e 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h @@ -81,7 +81,7 @@ class OptionalPredicatedTileAccessIterator { Base base_; /// Default constructor - OptionalPredicatedTileAccessIterator() : base_(){}; + OptionalPredicatedTileAccessIterator() : base_() {}; /// Constructs a TileIterator from its precomputed state, threadblock offset, /// and thread ID @@ -222,7 +222,7 @@ class OptionalPredicatedTileAccessIterator::value* ThreadMap_::kElementsPerAccess / 8> + sizeof_bits::value * ThreadMap_::kElementsPerAccess / 8> class OptionalRegularTileAccessIterator { public: using Shape = Shape_; diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h index 270d1f0944989..3bdbdadc474c2 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h @@ -330,7 +330,7 @@ class QuantBMmaTensorOp { FragmentB const& B, FragmentQScale const& scales, FragmentQOffset const& offsets) const { - Array const* ptr_B = + Array const* ptr_B = reinterpret_cast const*>(&B); IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B); } diff --git a/onnxruntime/core/mickey/gemm/device/quant_b4_gemm.h b/onnxruntime/core/mickey/gemm/device/quant_b4_gemm.h index 6fa4c2b03b4bd..bd7472a9099fd 100644 --- a/onnxruntime/core/mickey/gemm/device/quant_b4_gemm.h +++ b/onnxruntime/core/mickey/gemm/device/quant_b4_gemm.h @@ -18,23 +18,21 @@ #include "gemm/kernel/quant_b4_gemm.h" - namespace mickey { namespace gemm { namespace device { /** * @brief Kernel launcher for quantized GEMM with B matrix quantized to 4bits. -*/ + */ template < - typename QuantBlocking_, ///! Shape of the quantization block, either 1xb or bx1 - bool has_quant_offset, ///! Whether to use quantization offset - typename WarpShape_, ///! Warp-scoped matrix multiply-accumulate - int SplitKSerial_ = 1, ///! How many warps to split the K dimension in the same MxN block - int Stages_ = 3 ///! Stages of the pipelined mainloop -> + typename QuantBlocking_, ///! Shape of the quantization block, either 1xb or bx1 + bool has_quant_offset, ///! Whether to use quantization offset + typename WarpShape_, ///! Warp-scoped matrix multiply-accumulate + int SplitKSerial_ = 1, ///! How many warps to split the K dimension in the same MxN block + int Stages_ = 3 ///! Stages of the pipelined mainloop + > class QuantB4Gemm { -public: - + public: using QuantBlocking = QuantBlocking_; using WarpShape = WarpShape_; static const int kSplitK = SplitKSerial_; @@ -44,19 +42,18 @@ class QuantB4Gemm { using Args = typename Kernel::Params; static cutlass::Status run( - cudaStream_t stream, - cutlass::gemm::GemmCoord const & problem_size, - void* ptr_output, - size_t output_byte_stride, - void const *ptr_a, - size_t a_byte_stride, - void const *ptr_packed_b, - size_t b_byte_stride, - void const *ptr_scales, - size_t scales_byte_stride, - void const *ptr_zp = nullptr, - size_t zp_byte_stride = 0) { - + cudaStream_t stream, + cutlass::gemm::GemmCoord const& problem_size, + void* ptr_output, + size_t output_byte_stride, + void const* ptr_a, + size_t a_byte_stride, + void const* ptr_packed_b, + size_t b_byte_stride, + void const* ptr_scales, + size_t scales_byte_stride, + void const* ptr_zp = nullptr, + size_t zp_byte_stride = 0) { Args args(problem_size, ptr_output, output_byte_stride, ptr_a, a_byte_stride, ptr_packed_b, b_byte_stride, ptr_scales, scales_byte_stride, @@ -89,10 +86,8 @@ class QuantB4Gemm { return cutlass::Status::kSuccess; } - }; - } // namespace device } // namespace gemm } // namespace mickey diff --git a/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h b/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h index a0695dbbfd347..3e1ed51c81881 100644 --- a/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h +++ b/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h @@ -29,15 +29,15 @@ namespace gemm { namespace kernel { #if defined(_MSC_VER) && !defined(__clang__) - #pragma warning(push) - #pragma warning(disable:4200) +#pragma warning(push) +#pragma warning(disable : 4200) #endif -template -struct MmaLoopSharedBuffer{ +template +struct MmaLoopSharedBuffer { // Quantized weights are packed int4, each 16x16 tile of int4 // is packed into 8x8 tile of 16b (i.e. 8x16 tile of bytes) - using PackedBShape = cutlass::MatrixShape; + using PackedBShape = cutlass::MatrixShape; static_assert(sizeof(ElementT) == 2, "Only support 16b float types."); /// Buffer for prepacked weights @@ -69,12 +69,12 @@ struct MmaLoopSharedBuffer{ * @brief Fused GEMM kernel for fp16 x int4, where B matrix is blockwise quantized to 4bits. */ template < - typename QuantBlocking_, ///! Shape of the quantization block, either 1xb or bx1 - bool has_quant_offset_, ///! Whether the quantization has offset - typename WarpShape_, ///! Warp-scoped matrix multiply-accumulate - int SplitKSerial_ = 1, ///! How many warps to split the K dimension in the same MxN block - int Stages_ = 3 ///! Stages of the pipelined mainloop -> + typename QuantBlocking_, ///! Shape of the quantization block, either 1xb or bx1 + bool has_quant_offset_, ///! Whether the quantization has offset + typename WarpShape_, ///! Warp-scoped matrix multiply-accumulate + int SplitKSerial_ = 1, ///! How many warps to split the K dimension in the same MxN block + int Stages_ = 3 ///! Stages of the pipelined mainloop + > struct QuantB4Gemm { public: // @@ -94,19 +94,17 @@ struct QuantB4Gemm { // Type constraints verifications: // static_assert(kSplitK > 0 && ((kSplitK - 1) & kSplitK) == 0, - "SplitK must be positive and a power of 2"); + "SplitK must be positive and a power of 2"); static_assert(kStages > 1, "Number of pipeline stages must be greater than 1."); static_assert(kElementSize == 2, "Only support 16b float types."); static_assert(WarpShape::kN % 16 == 0, - "Weight B is packed as 16x16 tiles, warp shape must contain whole tiles!"); + "Weight B is packed as 16x16 tiles, warp shape must contain whole tiles!"); static_assert(WarpShape::kK % 32 == 0, - "K stride too small leading to inefficient global memory load!"); + "K stride too small leading to inefficient global memory load!"); // Need to explore the way to relax this for very small m value. - static_assert((WarpShape::kM % InstructionShape::kM == 0) - && (WarpShape::kN % InstructionShape::kN == 0) - && (WarpShape::kK % InstructionShape::kK == 0), + static_assert((WarpShape::kM % InstructionShape::kM == 0) && (WarpShape::kN % InstructionShape::kN == 0) && (WarpShape::kK % InstructionShape::kK == 0), "Warp shape must be multiple of instruction shape!"); /// switches for debug print @@ -117,7 +115,7 @@ struct QuantB4Gemm { using ATileLoader = mickey::gemm::warp::SwizzleTileLoader; using MetaLoader = mickey::gemm::warp::QuantBScaleLoader; - using WarpPackedBShape = cutlass::gemm::GemmShape<1, WarpShape::kN/2, WarpShape::kK>; + using WarpPackedBShape = cutlass::gemm::GemmShape<1, WarpShape::kN / 2, WarpShape::kK>; using PackedBLoader = mickey::gemm::warp::SwizzleTileLoader; // Need 4 tiles to fully utilize ldmatrix. And..... @@ -154,7 +152,7 @@ struct QuantB4Gemm { static constexpr int kFragPackedBStrideK = WarpPackedBShape::kN == 8 ? 64 : 32; static_assert(WarpShape::kK % kFragPackedBStrideK == 0); - static constexpr int kWarps = kSplitK; // TODO! more warps when we have a larger thread block shape + static constexpr int kWarps = kSplitK; // TODO! more warps when we have a larger thread block shape static int const kThreadCount = 32 * kWarps; using MainLoopSharedBuffer = MmaLoopSharedBuffer; @@ -249,50 +247,48 @@ struct QuantB4Gemm { cutlass::gemm::GemmCoord grid_tiled_shape_; void* const ptr_output_; const size_t output_byte_stride_; - void const * const ptr_a_; + void const* const ptr_a_; const size_t a_byte_stride_; - void const * const ptr_packed_b_; + void const* const ptr_packed_b_; const size_t b_byte_stride_; - void const * const ptr_scales_; + void const* const ptr_scales_; const size_t scales_byte_stride_; - void const * const ptr_offsets_; + void const* const ptr_offsets_; const size_t offsets_byte_stride_; int gemm_k_size_{0}; CUTLASS_HOST_DEVICE - Params() { } + Params() {} CUTLASS_HOST_DEVICE Params( - cutlass::gemm::GemmCoord const & problem_size, - void* ptr_output, - size_t output_byte_stride, - void const *ptr_a, - size_t a_byte_stride, - void const *ptr_packed_b, - size_t b_byte_stride, - void const *ptr_scales, - size_t scales_byte_stride, - void const *ptr_offsets = nullptr, - size_t offsets_byte_stride = 0 - ): - problem_size_(problem_size), - ptr_output_(ptr_output), - output_byte_stride_(output_byte_stride), - ptr_a_(ptr_a), - a_byte_stride_(a_byte_stride), - ptr_packed_b_(ptr_packed_b), - b_byte_stride_(b_byte_stride), - ptr_scales_(ptr_scales), - scales_byte_stride_(scales_byte_stride), - ptr_offsets_(ptr_offsets), - offsets_byte_stride_(offsets_byte_stride), - gemm_k_size_(mickey::round_up(mickey::div_up(problem_size.k(), kSplitK), WarpShape::kK)), - // TODO! grid_tiled_shape_ should be based on thread block shape - grid_tiled_shape_(cutlass::gemm::GemmCoord( - mickey::div_up(problem_size.m(), WarpShape::kM), - mickey::div_up(problem_size.n(), WarpShape::kN), - 1)) { } + cutlass::gemm::GemmCoord const& problem_size, + void* ptr_output, + size_t output_byte_stride, + void const* ptr_a, + size_t a_byte_stride, + void const* ptr_packed_b, + size_t b_byte_stride, + void const* ptr_scales, + size_t scales_byte_stride, + void const* ptr_offsets = nullptr, + size_t offsets_byte_stride = 0) : problem_size_(problem_size), + ptr_output_(ptr_output), + output_byte_stride_(output_byte_stride), + ptr_a_(ptr_a), + a_byte_stride_(a_byte_stride), + ptr_packed_b_(ptr_packed_b), + b_byte_stride_(b_byte_stride), + ptr_scales_(ptr_scales), + scales_byte_stride_(scales_byte_stride), + ptr_offsets_(ptr_offsets), + offsets_byte_stride_(offsets_byte_stride), + gemm_k_size_(mickey::round_up(mickey::div_up(problem_size.k(), kSplitK), WarpShape::kK)), + // TODO! grid_tiled_shape_ should be based on thread block shape + grid_tiled_shape_(cutlass::gemm::GemmCoord( + mickey::div_up(problem_size.m(), WarpShape::kM), + mickey::div_up(problem_size.n(), WarpShape::kN), + 1)) {} }; // @@ -300,10 +296,10 @@ struct QuantB4Gemm { // CUTLASS_HOST_DEVICE - QuantB4Gemm() { } + QuantB4Gemm() {} /// Determines whether kernel satisfies alignment - static cutlass::Status can_implement(const Params ¶ms) { + static cutlass::Status can_implement(const Params& params) { if (params.output_byte_stride_ >= std::numeric_limits::max() || params.a_byte_stride_ >= std::numeric_limits::max() || params.b_byte_stride_ >= std::numeric_limits::max() || @@ -321,7 +317,7 @@ struct QuantB4Gemm { return cutlass::Status::kErrorMisalignedOperand; } if ((params.problem_size_.k() % QuantBlocking::kRow != 0) || - (params.problem_size_.n() % QuantBlocking::kColumn) != 0){ + (params.problem_size_.n() % QuantBlocking::kColumn) != 0) { std::cerr << "QuantB4Gemm validation fail: partial quantization block not supported!" << std::endl; return cutlass::Status::kErrorInvalidProblem; } @@ -384,7 +380,7 @@ struct QuantB4Gemm { return cutlass::Status::kErrorInvalidProblem; } - if constexpr (kSplitK > 1){ + if constexpr (kSplitK > 1) { // TODO! Use thread block shape if (params.gemm_k_size_ < WarpShape::kK * kStages * 2) { // spliting too small, may not get enough iterations to rampup pipeline @@ -398,10 +394,10 @@ struct QuantB4Gemm { /// Executes one GEMM CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { + void operator()(Params const& params, SharedStorage& shared_storage) { // Early exit if CTA is out of range if (params.grid_tiled_shape_.m() <= blockIdx.x || - params.grid_tiled_shape_.n() <= blockIdx.y) { + params.grid_tiled_shape_.n() <= blockIdx.y) { // should not happen if (threadIdx.x == 0) { printf("CTA out of range %d, %d\n", blockIdx.x, blockIdx.y); @@ -422,14 +418,14 @@ struct QuantB4Gemm { assert_pass = false; if (lane_idx == 0) { printf("warp_idx %d exceeds kWarps %d! Should use %d threads per threadblock for kernel launch!\n", - warp_idx, kWarps, kThreadCount); + warp_idx, kWarps, kThreadCount); } } if (warp_idx_k != warp_idx) { assert_pass = false; if (lane_idx == 0) { printf("warp_idx_k %d should be equal to warp_idx %d while we don't yet specify thread block shape larger than warp shape!\n", - warp_idx_k, warp_idx); + warp_idx_k, warp_idx); } } assert(assert_pass); @@ -440,7 +436,7 @@ struct QuantB4Gemm { // so lead dimension byte size is coincidentally k/2 * 2 = k // and next dimension size is n/2 // - const int n_start = mul_power2(blockIdx.y); // TODO! change to thread block shape + const int n_start = mul_power2(blockIdx.y); // TODO! change to thread block shape const int n_end = min(params.problem_size_.n(), mul_power2(blockIdx.y + 1)); const int packed_n_start = (n_start) >> 1; const int packed_n_end = n_end >> 1; @@ -452,34 +448,34 @@ struct QuantB4Gemm { const int m_end = min(params.problem_size_.m(), mul_power2(blockIdx.x + 1)); PackedBLoader packed_b_loader{ - params.ptr_packed_b_, - int(params.b_byte_stride_), - packed_n_start, - packed_n_end, - k_start, - k_end, - lane_idx}; + params.ptr_packed_b_, + int(params.b_byte_stride_), + packed_n_start, + packed_n_end, + k_start, + k_end, + lane_idx}; MetaLoader meta_loader{ - lane_idx, - params.ptr_scales_, - int(params.scales_byte_stride_), - params.ptr_offsets_, - int(params.offsets_byte_stride_), - n_start, n_end}; + lane_idx, + params.ptr_scales_, + int(params.scales_byte_stride_), + params.ptr_offsets_, + int(params.offsets_byte_stride_), + n_start, n_end}; ATileLoader a_tile_loader{ - params.ptr_a_, - int(params.a_byte_stride_), - m_start, m_end, - mul_power2(k_start), mul_power2(k_end), // convert to byte based index - lane_idx}; + params.ptr_a_, + int(params.a_byte_stride_), + m_start, m_end, + mul_power2(k_start), mul_power2(k_end), // convert to byte based index + lane_idx}; // // Prologue: start loading from global memory to shared memory // - int load_k = k_start; // current k index for loading from global memory to shared memory + int load_k = k_start; // current k index for loading from global memory to shared memory int smem_write_stage = 0; uint8_t* packed_b_shared_ptr = shared_storage.smem[warp_idx].main_loop.shared_B.data(); ElementT* a_shared_ptr = shared_storage.smem[warp_idx].main_loop.shared_A.data(); @@ -489,7 +485,7 @@ struct QuantB4Gemm { if constexpr (kDebugPrintSteps) { if (lane_idx == 0) { printf("Warp: %d, m_start %d, m_end %d, n_start %d, n_end %d, k_start %d, k_end %d, packed_n_start %d, packed_n_end %d\n PackedB: %p, A: %p, Scales: %p\n", - warp_idx, m_start, m_end, n_start, n_end, k_start, k_end, packed_n_start, packed_n_end, packed_b_shared_ptr, a_shared_ptr, scales_shared_ptr); + warp_idx, m_start, m_end, n_start, n_end, k_start, k_end, packed_n_start, packed_n_end, packed_b_shared_ptr, a_shared_ptr, scales_shared_ptr); } } @@ -561,7 +557,7 @@ struct QuantB4Gemm { if constexpr (kDebugPrintSteps) { if (lane_idx == 0) { printf("Prefix: PackedB[%d] <- %p <- %p, A[%d] <- %p <- %p, fragment_scales[%d] <- load_k %d <- %p <- %p\n", - 0, packed_b_smem_read_ptr, packed_b_smem_write_ptr, 0, a_smem_read_ptr, a_smem_write_ptr, 0, load_k, scales_smem_read_ptr, scales_smem_write_ptr); + 0, packed_b_smem_read_ptr, packed_b_smem_write_ptr, 0, a_smem_read_ptr, a_smem_write_ptr, 0, load_k, scales_smem_read_ptr, scales_smem_write_ptr); } } @@ -592,8 +588,7 @@ struct QuantB4Gemm { // Main loop // proc_k = load_k - (kStages - 1) * WarpShape::kK // - while (load_k < k_end + (kStages - 1) * WarpShape::kK){ - + while (load_k < k_end + (kStages - 1) * WarpShape::kK) { // One stage has kMmaIterations, we unroll the main loop by 2, // as the meta data is loaded only once every stage, need 2 stages // to complete a double buffer cycle. This is necessary to make @@ -609,13 +604,13 @@ struct QuantB4Gemm { const int read_stage_diff = (smem_write_stage == (kStages - 2)) ? (1 - kStages) : 1; smem_write_stage = (smem_write_stage + 1) % kStages; scales_smem_write_ptr = const_cast(scales_smem_read_ptr); - if constexpr(has_quant_offset) { + if constexpr (has_quant_offset) { offsets_smem_write_ptr = const_cast(offsets_smem_read_ptr); } packed_b_smem_write_ptr = const_cast(packed_b_smem_read_ptr); a_smem_write_ptr = const_cast(a_smem_read_ptr); scales_smem_read_ptr += read_stage_diff * MainLoopSharedBuffer::kMetaSizePerIter; - if constexpr(has_quant_offset) { + if constexpr (has_quant_offset) { offsets_smem_read_ptr += read_stage_diff * MainLoopSharedBuffer::kMetaSizePerIter; } packed_b_smem_read_ptr += read_stage_diff * MainLoopSharedBuffer::kPackedBSizePerIter; @@ -637,10 +632,10 @@ struct QuantB4Gemm { fragment_scales[(next_iter2 / kMmaIterations) % 2], scales_smem_read_ptr, fragment_offsets[(next_iter2 / kMmaIterations) % 2], offsets_smem_read_ptr); - if constexpr(kDebugPrintB) { + if constexpr (kDebugPrintB) { if (lane_idx == 0) { printf("Mainloop, warp: %d, proc_k %d, load_k %d\nWritePtr: %p, ReadPtr: %p\n", - warp_idx, load_k - (kStages - 1) * WarpShape::kK, load_k, packed_b_smem_write_ptr, packed_b_smem_read_ptr); + warp_idx, load_k - (kStages - 1) * WarpShape::kK, load_k, packed_b_smem_write_ptr, packed_b_smem_read_ptr); } cutlass::debug::dump_shmem(packed_b_shared_ptr, MainLoopSharedBuffer::kPackedBSize); } @@ -668,7 +663,7 @@ struct QuantB4Gemm { if constexpr (kDebugPrintSteps) { if (lane_idx == 0) { - printf("A[%d] <- %p <- %p\n", (iter2 + 1) % 2, a_smem_read_ptr, a_smem_write_ptr); + printf("A[%d] <- %p <- %p\n", (iter2 + 1) % 2, a_smem_read_ptr, a_smem_write_ptr); } } a_tile_loader.load_fragment_k32(lane_idx, a_smem_read_ptr, next_iter * InstructionShape::kK * kElementSize, fragment_a[(iter2 + 1) % 2].data()); @@ -682,7 +677,8 @@ struct QuantB4Gemm { if (lane_id == 0) { printf("==== A tiles =======\n"); } - const char* const format = (lane_id == 31) ? "%f, %f\n\n" : ((lane_id % 4) == 3) ? "%f, %f\n" : "%f, %f, "; + const char* const format = (lane_id == 31) ? "%f, %f\n\n" : ((lane_id % 4) == 3) ? "%f, %f\n" + : "%f, %f, "; const ElementT* a_ptr = fragment_a[iter2 % 2].data(); for (int m2_tile = 0; m2_tile < (WarpShape::kM / InstructionShape::kM); ++m2_tile, a_ptr += 8) { printf(format, float(a_ptr[0]), float(a_ptr[1])); @@ -718,7 +714,8 @@ struct QuantB4Gemm { if (lane_id == 0) { printf("======= C tiles in warp %d =======\n", warp_idx); } - const char* const format = (lane_id == 31) ? "%f, %f\n\n" : ((lane_id % 4) == 3) ? "%f, %f\n" : "%f, %f, "; + const char* const format = (lane_id == 31) ? "%f, %f\n\n" : ((lane_id % 4) == 3) ? "%f, %f\n" + : "%f, %f, "; for (int n_tile = 0; n_tile < (WarpShape::kN / InstructionShape::kN); ++n_tile) { for (int m_tile = 0; m_tile < (WarpShape::kM / InstructionShape::kM); ++m_tile, c_ptr += 4) { // since InstructionShape::kM is 16, we can print 2 tiles @@ -739,7 +736,7 @@ struct QuantB4Gemm { using Float4 = cutlass::Array; // hopefully utilize 128b st.shared.b128 constexpr int kAccLoads = MmaOp::FragmentC::kElements / 4; static_assert(kAccLoads * 4 == MmaOp::FragmentC::kElements); - if (warp_idx != 0){ + if (warp_idx != 0) { Float4* d_smem_ptr = reinterpret_cast(shared_storage.smem[warp_idx].shared_Acc.data()); d_smem_ptr += lane_idx; Float4* f4s = reinterpret_cast(accumulators.data()); @@ -784,7 +781,8 @@ struct QuantB4Gemm { d_smem_ptr += 32; if constexpr (kDebugPrintC) { - const char* const format = (lane_idx == 31) ? "%f, %f\n\n" : ((lane_idx % 4) == 3) ? "%f, %f\n" : "%f, %f, "; + const char* const format = (lane_idx == 31) ? "%f, %f\n\n" : ((lane_idx % 4) == 3) ? "%f, %f\n" + : "%f, %f, "; printf(format, float(other_acc[double_buffer_idx][0]), float(other_acc[double_buffer_idx][1])); printf(format, float(other_acc[double_buffer_idx][2]), float(other_acc[double_buffer_idx][3])); } @@ -822,15 +820,13 @@ struct QuantB4Gemm { CUTLASS_PRAGMA_UNROLL for (int m_tile = 0; m_tile < (WarpShape::kM / 8); ++m_tile, m += 8, ++c_ptr) { if (n < n_end && m < m_end) { - *(output_ptr + m * output_stride + n/2) = __float22half2_rn(c_ptr[0]); + *(output_ptr + m * output_stride + n / 2) = __float22half2_rn(c_ptr[0]); } } } - } }; - } // namespace kernel } // namespace gemm } // namespace mickey diff --git a/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h b/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h index 79c582279f2c8..d2e0cfdc31d1a 100644 --- a/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h +++ b/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h @@ -23,11 +23,10 @@ namespace detail { /** * @brief Convert (4b weights - 8) to fp16 using bits operations. -*/ + */ CUTLASS_DEVICE -void weightsMinuEight2Half(uint32_t const &weights, - cutlass::Array& dest) -{ +void weightsMinuEight2Half(uint32_t const& weights, + cutlass::Array& dest) { // 4b weights are arranged as [0, 2, 4, 6, 1, 3, 5, 7], so that adjacent // weights are in adjacent 16b half words. // w & 0x000f000f --> take out element 0, 1 @@ -47,24 +46,23 @@ void weightsMinuEight2Half(uint32_t const &weights, // // 1.125 instruction per weight, 9 instructions in total. - uint32_t* b32s = reinterpret_cast(dest.data()); + uint32_t* b32s = reinterpret_cast(dest.data()); const uint32_t high_8s = weights >> 8; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 500)) asm volatile( - " lop3.b32 %0, %4, 0x000f000f, %6, 0xea;\n" - " lop3.b32 %1, %4, 0x00f000f0, %7, 0xea;\n" - " lop3.b32 %2, %5, 0x000f000f, %6, 0xea;\n" - " lop3.b32 %3, %5, 0x00f000f0, %7, 0xea;\n" - " sub.rn.f16x2 %0, %0, %10;\n" // q_w - 1032.0 - " fma.rn.f16x2 %1, %1, %8, %9;\n" // 1.0 * q_w + (-72.0) - " sub.rn.f16x2 %2, %2, %10;\n" - " fma.rn.f16x2 %3, %3, %8, %9;\n" - : "=r"(b32s[0]), "=r"(b32s[1]), "=r"(b32s[2]), "=r"(b32s[3]) - : "r"(weights), "r"(high_8s), - "r"(0x64006400), "r"(0x54005400) - "r"(0x3c003c00), "r"(0xd480d480), - "r"(0x64086408)); + " lop3.b32 %0, %4, 0x000f000f, %6, 0xea;\n" + " lop3.b32 %1, %4, 0x00f000f0, %7, 0xea;\n" + " lop3.b32 %2, %5, 0x000f000f, %6, 0xea;\n" + " lop3.b32 %3, %5, 0x00f000f0, %7, 0xea;\n" + " sub.rn.f16x2 %0, %0, %10;\n" // q_w - 1032.0 + " fma.rn.f16x2 %1, %1, %8, %9;\n" // 1.0 * q_w + (-72.0) + " sub.rn.f16x2 %2, %2, %10;\n" + " fma.rn.f16x2 %3, %3, %8, %9;\n" + : "=r"(b32s[0]), "=r"(b32s[1]), "=r"(b32s[2]), "=r"(b32s[3]) + : "r"(weights), "r"(high_8s), + "r"(0x64006400), "r"(0x54005400) "r"(0x3c003c00), "r"(0xd480d480), + "r"(0x64086408)); #else assert(false); #endif @@ -72,11 +70,10 @@ void weightsMinuEight2Half(uint32_t const &weights, /** * @brief Convert 4b weights to fp16 using bits operations. -*/ + */ CUTLASS_DEVICE -void weights2Half([[maybe_unused]] uint32_t const &weights, - cutlass::Array& dest) -{ +void weights2Half([[maybe_unused]] uint32_t const& weights, + cutlass::Array& dest) { // 4b weights are arranged as [0, 2, 4, 6, 1, 3, 5, 7], so that adjacent // weights are in adjacent 16b half words. // w & 0x000f000f --> take out element 0, 1 @@ -96,23 +93,22 @@ void weights2Half([[maybe_unused]] uint32_t const &weights, // // 1.125 instruction per weight, 9 instructions in total. - uint32_t* b32s = reinterpret_cast(dest.data()); + uint32_t* b32s = reinterpret_cast(dest.data()); const uint32_t high_8s = weights >> 8; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 500)) asm volatile( - " lop3.b32 %0, %4, 0x000f000f, %6, 0xea;\n" - " lop3.b32 %1, %4, 0x00f000f0, %7, 0xea;\n" - " lop3.b32 %2, %5, 0x000f000f, %6, 0xea;\n" - " lop3.b32 %3, %5, 0x00f000f0, %7, 0xea;\n" - " sub.rn.f16x2 %0, %0, %6;\n" // q_w - 1024.0 - " fma.rn.f16x2 %1, %1, %8, %9;\n" // 1.0 * q_w + (-64.0) - " sub.rn.f16x2 %2, %2, %6;\n" - " fma.rn.f16x2 %3, %3, %8, %9;\n" - : "=r"(b32s[0]), "=r"(b32s[1]), "=r"(b32s[2]), "=r"(b32s[3]) - : "r"(weights), "r"(high_8s), - "r"(0x64006400), "r"(0x54005400) - "r"(0x3c003c00), "r"(0xd400d400)); + " lop3.b32 %0, %4, 0x000f000f, %6, 0xea;\n" + " lop3.b32 %1, %4, 0x00f000f0, %7, 0xea;\n" + " lop3.b32 %2, %5, 0x000f000f, %6, 0xea;\n" + " lop3.b32 %3, %5, 0x00f000f0, %7, 0xea;\n" + " sub.rn.f16x2 %0, %0, %6;\n" // q_w - 1024.0 + " fma.rn.f16x2 %1, %1, %8, %9;\n" // 1.0 * q_w + (-64.0) + " sub.rn.f16x2 %2, %2, %6;\n" + " fma.rn.f16x2 %3, %3, %8, %9;\n" + : "=r"(b32s[0]), "=r"(b32s[1]), "=r"(b32s[2]), "=r"(b32s[3]) + : "r"(weights), "r"(high_8s), + "r"(0x64006400), "r"(0x54005400) "r"(0x3c003c00), "r"(0xd400d400)); #else assert(false); #endif @@ -124,10 +120,10 @@ void weights2Half([[maybe_unused]] uint32_t const &weights, /// Loader for blockwise quantization scales template < - typename QuantBlocking_, ///! Shape of the quant block (concept: MatrixShape) - typename WarpShape_, ///! Shape of the warp tile (concept: GemmShape kM ignored) + typename QuantBlocking_, ///! Shape of the quant block (concept: MatrixShape) + typename WarpShape_, ///! Shape of the warp tile (concept: GemmShape kM ignored) typename ElementT_ = cutlass::half_t, ///! Data type of the scales and dequantized B - bool has_offsets = false, ///! Whether the quantization has offsets + bool has_offsets = false, ///! Whether the quantization has offsets bool DebugPrint = false> struct QuantBScaleLoader; @@ -170,44 +166,42 @@ struct QuantBScaleLoader, WarpShape_, Eleme // HBM -> SMEM, 16 bytes per load, no leftover since WarpShape::kN is multiple of 16 static constexpr int kSmemSize = WarpShape::kN * kMetaChunkCount; static constexpr int kScaleLoadThreads = (kSmemSize * sizeof(ElementT)) / 16; - static_assert(kScaleLoadThreads <= 16); // shape up to 64x64, 16 threads can load all scales + static_assert(kScaleLoadThreads <= 16); // shape up to 64x64, 16 threads can load all scales using FragmentScales = cutlass::Array; static constexpr int kOffsetLoadThreads = (kSmemSize * sizeof(OffsetT)) / 16; - static_assert(kOffsetLoadThreads <= 16); // shape up to 64x64, 16 threads can load all offsets + static_assert(kOffsetLoadThreads <= 16); // shape up to 64x64, 16 threads can load all offsets using FragmentOffsets = typename std::conditional::type; + FragmentScales, + std::monostate>::type; // // Data members // const int n_cnt; - const uint8_t * const scales_byte_p; + const uint8_t* const scales_byte_p; const int scales_byte_stride; - const uint8_t * const offsets_byte_p; + const uint8_t* const offsets_byte_p; const int offsets_byte_stride; // // Methods // template - CUTLASS_DEVICE - static const uint8_t* get_scales_p(const void* ptr_scales, int scales_byte_stride, int k, int n) { - return (ptr_scales == nullptr) ? nullptr : - reinterpret_cast(ptr_scales) + k * scales_byte_stride + n * sizeof(T); + CUTLASS_DEVICE static const uint8_t* get_scales_p(const void* ptr_scales, int scales_byte_stride, int k, int n) { + return (ptr_scales == nullptr) ? nullptr : reinterpret_cast(ptr_scales) + k * scales_byte_stride + n * sizeof(T); } /// Initializes the scale loader, pointing to the start of the scales tensor CUTLASS_DEVICE QuantBScaleLoader( int lane_idx, - void const *ptr_scales, + void const* ptr_scales, int scales_byte_stride, - void const *ptr_offsets, // dummy to make the interface consistent with QuantBScaleOffsetLoader + void const* ptr_offsets, // dummy to make the interface consistent with QuantBScaleOffsetLoader int offsets_byte_stride, int start_n, int end_n) @@ -215,12 +209,11 @@ struct QuantBScaleLoader, WarpShape_, Eleme scales_byte_p(get_scales_p(ptr_scales, scales_byte_stride, 0, start_n)), scales_byte_stride(scales_byte_stride), offsets_byte_p(get_scales_p(ptr_offsets, offsets_byte_stride, 0, start_n)), - offsets_byte_stride(offsets_byte_stride) - { + offsets_byte_stride(offsets_byte_stride) { assert(ptr_scales != nullptr); assert(scales_byte_stride > 0 && mod_power2<16>(scales_byte_stride) == 0); assert(scales_byte_stride >= end_n * sizeof(ElementT)); - if constexpr(has_offsets) { + if constexpr (has_offsets) { assert(ptr_offsets != nullptr); assert(offsets_byte_stride > 0 && mod_power2<16>(offsets_byte_stride) == 0); assert(offsets_byte_stride >= end_n * sizeof(OffsetT)); @@ -235,7 +228,7 @@ struct QuantBScaleLoader, WarpShape_, Eleme CUTLASS_DEVICE void load_to_smem(const int lane_idx, const int start_k, const int k_cnt, ElementT* smem, OffsetT* offset_smem) const { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - { // Load scales to smem + { // Load scales to smem int lane_ptr_offset = mul_power2<16 / sizeof(ElementT)>(lane_idx); // Column-wise quantization, every column has its own scale/offset @@ -251,13 +244,12 @@ struct QuantBScaleLoader, WarpShape_, Eleme " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" " @p cp.async.cg.shared.global [%1], [%2], %3, %4;\n" - "}\n" - ::"r"((int)(kScaleLoadThreads > lane_idx)), - "r"(smem_int_ptr), - "l"(&scales_ptr[k_idx * scales_byte_stride + n_idx * sizeof(ElementT)]), - "n"(16), "r"(src_in_bytes)); + "}\n" ::"r"((int)(kScaleLoadThreads > lane_idx)), + "r"(smem_int_ptr), + "l"(&scales_ptr[k_idx * scales_byte_stride + n_idx * sizeof(ElementT)]), + "n"(16), "r"(src_in_bytes)); } - if constexpr(has_offsets) { // Load offset to smem + if constexpr (has_offsets) { // Load offset to smem int lane_ptr_offset = mul_power2<16 / sizeof(OffsetT)>(lane_idx); // Column-wise quantization, every column has its own scale/offset @@ -274,26 +266,25 @@ struct QuantBScaleLoader, WarpShape_, Eleme " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" " @p cp.async.cg.shared.global [%1], [%2], %3, %4;\n" - "}\n" - ::"r"((int)(kOffsetLoadThreads > lane_idx)), - "r"(smem_int_ptr), - "l"(&offsets_ptr[k_idx * offsets_byte_stride + n_idx * sizeof(OffsetT)]), - "n"(16), "r"(src_in_bytes)); + "}\n" ::"r"((int)(kOffsetLoadThreads > lane_idx)), + "r"(smem_int_ptr), + "l"(&offsets_ptr[k_idx * offsets_byte_stride + n_idx * sizeof(OffsetT)]), + "n"(16), "r"(src_in_bytes)); } #else - assert(false); - (void)(lane_idx); - (void)(start_k); - (void)(k_cnt); - (void)(smem); - (void)(offset_smem); + assert(false); + (void)(lane_idx); + (void)(start_k); + (void)(k_cnt); + (void)(smem); + (void)(offset_smem); #endif } CUTLASS_DEVICE static void load_fragment(const int lane_idx, - FragmentScales &frag_scales, const ElementT* smem, - FragmentOffsets &frag_offsets, const OffsetT* offset_smem) { + FragmentScales& frag_scales, const ElementT* smem, + FragmentOffsets& frag_offsets, const OffsetT* offset_smem) { const int n_idx = div_power2<4>(lane_idx); ElementT const* scales_ptr = smem + n_idx; [[maybe_unused]] OffsetT const* offset_ptr = offset_smem + n_idx; @@ -301,7 +292,7 @@ struct QuantBScaleLoader, WarpShape_, Eleme CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kSmemSize / 8; ++i) { frag_scales[i] = scales_ptr[i << 3]; - if constexpr(has_offsets) { + if constexpr (has_offsets) { uint16_t v = offset_ptr[i << 3]; frag_offsets[i] = cutlass::half_t(__ushort2half_rn(v)); } @@ -315,13 +306,12 @@ struct QuantBScaleLoader, WarpShape_, Eleme /// thus the FragmentB has (WarpShape::kN / 8) * 2 * 2 elements. template - CUTLASS_DEVICE - static void dequant_k16( + CUTLASS_DEVICE static void dequant_k16( const int k_iter, - cutlass::Array const &frag_pack_b, - FragmentScales const &frag_scales, - FragmentOffsets const &frag_offsets, - FragmentB &frag_b) { + cutlass::Array const& frag_pack_b, + FragmentScales const& frag_scales, + FragmentOffsets const& frag_offsets, + FragmentB& frag_b) { // Each 32b number in packed B represent a 16x16 tile constexpr int kPackedBNTiles = WarpShape::kN / 16; constexpr int kPackedBKStride = PackedBSize / kPackedBNTiles; @@ -334,7 +324,7 @@ struct QuantBScaleLoader, WarpShape_, Eleme const int meta_k = k_iter / (QuantBlocking::kRow / 16); half const* scales = reinterpret_cast(frag_scales.data() + meta_k * kMetaFragSize); [[maybe_unused]] half const* offsets = nullptr; - if constexpr(has_offsets) { + if constexpr (has_offsets) { offsets = reinterpret_cast(frag_offsets.data() + meta_k * kMetaFragSize); } @@ -349,7 +339,7 @@ struct QuantBScaleLoader, WarpShape_, Eleme cutlass::Array ws; half2* weight_pair = reinterpret_cast(ws.data()); - if constexpr(has_offsets) { + if constexpr (has_offsets) { detail::weights2Half(frag_pack_b[b_idx], ws); half2 offset_pair = __half2half2(offsets[nn * 2]); half2 offset_pair1 = __half2half2(offsets[nn * 2 + 1]); @@ -392,10 +382,8 @@ struct QuantBScaleLoader, WarpShape_, Eleme } } } - }; - /// Specialization for row-wise quantization, i.e. QuantBlocking::kRow == 1 template < int block_size_, @@ -430,44 +418,42 @@ struct QuantBScaleLoader, WarpShape_, Eleme // HBM -> SMEM, 16 bytes per load, no leftover since WarpShape::kN is multiple of 16 static constexpr int kSmemSize = WarpShape::kK * kMetaChunkCount; static constexpr int kScaleLoadThreads = (kSmemSize * sizeof(ElementT)) / 16; - static_assert(kScaleLoadThreads <= 16); // shape up to 64x64, 16 threads can load all scales + static_assert(kScaleLoadThreads <= 16); // shape up to 64x64, 16 threads can load all scales using FragmentScales = cutlass::Array; static constexpr int kOffsetLoadThreads = (kSmemSize * sizeof(OffsetT)) / 16; - static_assert(kOffsetLoadThreads <= 16); // shape up to 64x64, 16 threads can load all offsets + static_assert(kOffsetLoadThreads <= 16); // shape up to 64x64, 16 threads can load all offsets using FragmentOffsets = typename std::conditional::type; + FragmentScales, + std::monostate>::type; // // Data members // const int n_cnt; - const uint8_t * const scales_byte_p; + const uint8_t* const scales_byte_p; const int scales_byte_stride; - const uint8_t * const offsets_byte_p; + const uint8_t* const offsets_byte_p; const int offsets_byte_stride; // // Methods // template - CUTLASS_DEVICE - static const uint8_t* get_scales_p(const void* ptr_scales, int scales_byte_stride, int k, int n) { - return (ptr_scales == nullptr) ? nullptr : - reinterpret_cast(ptr_scales) + n * scales_byte_stride + k * sizeof(T); + CUTLASS_DEVICE static const uint8_t* get_scales_p(const void* ptr_scales, int scales_byte_stride, int k, int n) { + return (ptr_scales == nullptr) ? nullptr : reinterpret_cast(ptr_scales) + n * scales_byte_stride + k * sizeof(T); } /// Initializes the scale loader, pointing to the start of the scales tensor CUTLASS_DEVICE QuantBScaleLoader( int lane_idx, - void const *ptr_scales, + void const* ptr_scales, int scales_byte_stride, - void const *ptr_offsets, // dummy to make the interface consistent with QuantBScaleOffsetLoader + void const* ptr_offsets, // dummy to make the interface consistent with QuantBScaleOffsetLoader int offsets_byte_stride, int start_n, int end_n) @@ -475,11 +461,10 @@ struct QuantBScaleLoader, WarpShape_, Eleme scales_byte_p(get_scales_p(ptr_scales, scales_byte_stride, 0, start_n / QuantBlocking::kColumn)), scales_byte_stride(scales_byte_stride), offsets_byte_p(get_scales_p(ptr_offsets, offsets_byte_stride, 0, start_n / QuantBlocking::kColumn)), - offsets_byte_stride(offsets_byte_stride) - { + offsets_byte_stride(offsets_byte_stride) { assert(ptr_scales != nullptr); assert(scales_byte_stride > 0 && mod_power2<16>(scales_byte_stride) == 0); - if constexpr(has_offsets) { + if constexpr (has_offsets) { assert(ptr_offsets != nullptr); assert(offsets_byte_stride > 0 && mod_power2<16>(offsets_byte_stride) == 0); } else { @@ -507,13 +492,12 @@ struct QuantBScaleLoader, WarpShape_, Eleme " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" " @p cp.async.cg.shared.global [%1], [%2], %3, %4;\n" - "}\n" - ::"r"((int)(kScaleLoadThreads > lane_idx)), - "r"(smem_int_ptr), - "l"(&scales_ptr[n_idx * scales_byte_stride + k_idx * sizeof(ElementT)]), - "n"(16), "r"(src_in_bytes)); + "}\n" ::"r"((int)(kScaleLoadThreads > lane_idx)), + "r"(smem_int_ptr), + "l"(&scales_ptr[n_idx * scales_byte_stride + k_idx * sizeof(ElementT)]), + "n"(16), "r"(src_in_bytes)); } - if constexpr(has_offsets) { + if constexpr (has_offsets) { // Load offsets to smem int lane_ptr_offset = mul_power2<16 / sizeof(OffsetT)>(lane_idx); const uint8_t* offsets_ptr = offsets_byte_p + start_k * sizeof(OffsetT); @@ -527,26 +511,25 @@ struct QuantBScaleLoader, WarpShape_, Eleme " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" " @p cp.async.cg.shared.global [%1], [%2], %3, %4;\n" - "}\n" - ::"r"((int)(kOffsetLoadThreads > lane_idx)), - "r"(smem_int_ptr), - "l"(&offsets_ptr[n_idx * offsets_byte_stride + k_idx * sizeof(OffsetT)]), - "n"(16), "r"(src_in_bytes)); + "}\n" ::"r"((int)(kOffsetLoadThreads > lane_idx)), + "r"(smem_int_ptr), + "l"(&offsets_ptr[n_idx * offsets_byte_stride + k_idx * sizeof(OffsetT)]), + "n"(16), "r"(src_in_bytes)); } #else - assert(false); - (void)(lane_idx); - (void)(start_k); - (void)(k_cnt); - (void)(smem); - (void)(offset_smem); + assert(false); + (void)(lane_idx); + (void)(start_k); + (void)(k_cnt); + (void)(smem); + (void)(offset_smem); #endif } CUTLASS_DEVICE static void load_fragment(const int lane_idx, - FragmentScales &frag_scales, const ElementT* smem, - FragmentOffsets &frag_offsets, const OffsetT* offset_smem) { + FragmentScales& frag_scales, const ElementT* smem, + FragmentOffsets& frag_offsets, const OffsetT* offset_smem) { // Row-wise quantization, every row has its own scale/offset, elements have been rearraged // such that we can load two tile at a time. // T0 T0 @@ -560,17 +543,17 @@ struct QuantBScaleLoader, WarpShape_, Eleme const int lane_offset = mod_power2<4>(lane_idx) << 2; const uint32_t* scales_ptr = reinterpret_cast(smem + lane_offset); [[maybe_unused]] const uint32_t* offsets_ptr = nullptr; - if constexpr(has_offsets) { + if constexpr (has_offsets) { offsets_ptr = reinterpret_cast(offset_smem + lane_offset); } CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentScales::kElements; i += 4) { - uint32_t* frag_ptr = reinterpret_cast(frag_scales.data() + i); + uint32_t* frag_ptr = reinterpret_cast(frag_scales.data() + i); frag_ptr[0] = scales_ptr[0]; frag_ptr[1] = scales_ptr[1]; scales_ptr += 8; - if constexpr(has_offsets) { + if constexpr (has_offsets) { // offsets are always 4 a group, this give us an opportunity to use // a little trick to reduce the number of instructions. // So here 4 offset a, b, c, d, we convert them to fp16, but not quite, @@ -582,11 +565,11 @@ struct QuantBScaleLoader, WarpShape_, Eleme const uint32_t ab = offsets_ptr[0]; const uint32_t cd = ab >> 4; asm volatile( - " lop3.b32 %0, %2, 0x000f000f, %4, 0xea;\n" - " lop3.b32 %1, %3, 0x00f000f0, %5, 0xea;\n" - : "=r"(offset_pair[0]), "=r"(offset_pair[1]) - : "r"(ab), "r"(cd), - "r"(0x64006400), "r"(0x54005400)); + " lop3.b32 %0, %2, 0x000f000f, %4, 0xea;\n" + " lop3.b32 %1, %3, 0x00f000f0, %5, 0xea;\n" + : "=r"(offset_pair[0]), "=r"(offset_pair[1]) + : "r"(ab), "r"(cd), + "r"(0x64006400), "r"(0x54005400)); } offsets_ptr += 4; } @@ -598,14 +581,13 @@ struct QuantBScaleLoader, WarpShape_, Eleme /// Dequantize a block of (16, WarpShape::kN) packed int4 weights to 16b float. /// This block has (WarpShape::kN / 8) * 2 tiles, each tile has 2 elements per thread, /// thus the FragmentB has (WarpShape::kN / 8) * 2 * 2 elements. - template - CUTLASS_DEVICE - static void dequant_k16( + template + CUTLASS_DEVICE static void dequant_k16( const int k_iter, - cutlass::Array const &frag_pack_b, - FragmentScales const &frag_scales, - FragmentOffsets const &frag_offsets, - FragmentB &frag_b) { + cutlass::Array const& frag_pack_b, + FragmentScales const& frag_scales, + FragmentOffsets const& frag_offsets, + FragmentB& frag_b) { // Each 32b number in packed B represent a 16x16 tile constexpr int kPackedBNTiles = WarpShape::kN / 16; constexpr int kPackedBKStride = PackedBSize / kPackedBNTiles; @@ -618,9 +600,9 @@ struct QuantBScaleLoader, WarpShape_, Eleme half2* const fb_pair = reinterpret_cast(frag_b.data() + nn * 8); const int meta_n = (nn * 16) / QuantBlocking::kColumn; const int idx = meta_n * kMetaFragSize + (k_iter * 4); - half2 const* const scale_pair = reinterpret_cast(frag_scales.data() + idx); // k_offset / 16 * 4 + half2 const* const scale_pair = reinterpret_cast(frag_scales.data() + idx); // k_offset / 16 * 4 [[maybe_unused]] half2 const* offsets = nullptr; - if constexpr(has_offsets) { + if constexpr (has_offsets) { offsets = reinterpret_cast(frag_offsets.data() + idx); } cutlass::Array ws; @@ -629,17 +611,17 @@ struct QuantBScaleLoader, WarpShape_, Eleme // a group of 4 offsets was converted to a + 1024.0, b + 1024.0, c + 64.0, d + 64.0 // when loaded from shared memory. { - uint32_t* b32s = reinterpret_cast(ws.data()); + uint32_t* b32s = reinterpret_cast(ws.data()); const uint32_t low_8s = frag_pack_b[b_idx]; const uint32_t high_8s = low_8s >> 8; asm volatile( - " lop3.b32 %0, %4, 0x000f000f, 0x64006400, 0xea;\n" - " lop3.b32 %1, %4, 0x00f000f0, 0x54005400, 0xea;\n" - " lop3.b32 %2, %5, 0x000f000f, 0x64006400, 0xea;\n" - " lop3.b32 %3, %5, 0x00f000f0, 0x54005400, 0xea;\n" - : "=r"(b32s[0]), "=r"(b32s[1]), "=r"(b32s[2]), "=r"(b32s[3]) - : "r"(low_8s), "r"(high_8s)); + " lop3.b32 %0, %4, 0x000f000f, 0x64006400, 0xea;\n" + " lop3.b32 %1, %4, 0x00f000f0, 0x54005400, 0xea;\n" + " lop3.b32 %2, %5, 0x000f000f, 0x64006400, 0xea;\n" + " lop3.b32 %3, %5, 0x00f000f0, 0x54005400, 0xea;\n" + : "=r"(b32s[0]), "=r"(b32s[1]), "=r"(b32s[2]), "=r"(b32s[3]) + : "r"(low_8s), "r"(high_8s)); } weight_pair[0] = __hsub2(weight_pair[0], offsets[0]); @@ -681,7 +663,6 @@ struct QuantBScaleLoader, WarpShape_, Eleme } } } - }; } // namespace warp diff --git a/onnxruntime/core/mickey/gemm/warp/swizzle_tile_loader.h b/onnxruntime/core/mickey/gemm/warp/swizzle_tile_loader.h index 6549a441abb45..f28265e1df7b1 100644 --- a/onnxruntime/core/mickey/gemm/warp/swizzle_tile_loader.h +++ b/onnxruntime/core/mickey/gemm/warp/swizzle_tile_loader.h @@ -39,679 +39,674 @@ class SwizzleTileLoader; template class SwizzleTileLoader { - public: - static constexpr int SmemDimM = SmemDimM_; - static constexpr int SmemDimK = 64; - static constexpr int kLoadVectorSize = 16; // one cp.async loads 16 bytes - static constexpr int kBlockSize = SmemDimM * SmemDimK; - static constexpr int kTiles = (SmemDimM / 8) * (SmemDimK / 16); - - // Swizzle pattern is 4x8 - static constexpr int kSwizzleK = SmemDimK / kLoadVectorSize; - static_assert(kSwizzleK == cute::_4::value); - static constexpr int kSwizzleM = cute::_8::value; - static constexpr int kSwizzleTileSize = kSwizzleK * kSwizzleM; - using Swizzled64 = decltype( - cute::composition(cute::Swizzle<2,0,3>{}, - cute::Layout, - cute::Stride>{})); - - static constexpr int kThreads = 32; - static constexpr int kGmemLoadStrideM = kThreads / kSwizzleK; - static_assert(kGmemLoadStrideM * kSwizzleK == kThreads); - static_assert(SmemDimM % kGmemLoadStrideM == 0); - - // During pipelined MMA, each stage (processing a tile) is split - // into multiple mma iterations. We need to somehow split the - // the loading of global memory tile into multiple calls, doing - // our best to help spread these actions across different iterations - // in a stage. - static constexpr int kGloadSplit = SmemDimM / kGmemLoadStrideM; - - private: - /// Pointer to global memory to load data from - uint8_t const* g_ptr_{nullptr}; - /// Iteration boundaries in the M or N dimension - int mn_cnt_{0}; - /// Iteration boundaries in the K dimension, in strides of 16 - int k_cnt_{0}; - /// Stride in bytes to advance to next row in m or n dimension - const int stride_; - - public: - CUTLASS_DEVICE - SwizzleTileLoader( - void const* data_ptr, ///< Pointer to the global memory tiles - int byte_stride, ///< Stride in bytes to advance to next row - int mn_start, ///< Starting position in the M or N dimension - int mn_end, ///< End position in the M or N dimension - int k_start, ///< Starting position in the K dimension - int k_end, ///< End position in the K dimension - int lane_id) ///< ID of each participating thread - : stride_(byte_stride) { - #ifndef NDEBUG - bool assertion_pass = true; - if (reinterpret_cast(data_ptr) % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("data_ptr: %p is not aligned to 16B boundary!\n", data_ptr); - } - } - if (byte_stride % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("byte_stride: %d is not aligned to 16B boundary!\n", byte_stride); - } - } - if (k_start % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("k_start: %d is not aligned to 16B boundary!\n", k_start); - } - } - if (k_end % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("k_end: %d is not aligned to 16B boundary!\n", k_end); - } - } - if (mn_end <= mn_start) { - assertion_pass = false; - if (lane_id == 0) { - printf("mn_end: %d is less than or equal to mn_start: %d!\n", mn_end, mn_start); - } - } - if (k_end <= k_start) { - assertion_pass = false; - if (lane_id == 0) { - printf("k_end: %d is less than or equal to k_start: %d!\n", k_end, k_start); - } - } - if (lane_id < 0 || lane_id >= kThreads) { - assertion_pass = false; - if (lane_id == 0) { - printf("Warp based loader, lane_id should be [0-32) but it is: %d!\n", lane_id); - } - } - assert(assertion_pass); - #endif - - int lane_m = div_power2(lane_id); - int lane_k = mod_power2(lane_id); - mn_start += lane_m; - k_start += mul_power2(lane_k); - - mn_cnt_ = div_up(mn_end - mn_start, kGmemLoadStrideM); - k_cnt_ = div_up(k_end - k_start, kSwizzleK * kLoadVectorSize); - if (mn_cnt_ <= 0 || k_cnt_ <= 0) { - mn_cnt_ = 0; - k_cnt_ = 0; - g_ptr_ = nullptr; - return; - } - g_ptr_ = reinterpret_cast(data_ptr) + mn_start * byte_stride + k_start; - // if (lane_id == 0) - // printf("lane_id: %d, mn_start: %d, mn_end: %d, k_start: %d, k_end: %d, g_ptr: %p\n", lane_id, mn_start, mn_end, k_start, k_end, g_ptr_); - } - - /** - * @brief Load a row major tile (SmemDimM, 64) from global memory to shared memory - */ - CUTLASS_DEVICE - void load_to_smem(const int lane_id, void* smem) { - // Here we rely on the fact that kThreads is 32, same as the swizzle pattern size - static_assert(kGmemLoadStrideM == kSwizzleM); - const uint8_t* data_ptr = g_ptr_; - uint8_t* smem_ptr = reinterpret_cast(smem) + mul_power2(Swizzled64{}(lane_id)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < SmemDimM / kSwizzleM; ++i) { - cutlass::arch::cp_async( - smem_ptr, data_ptr, g_ptr_ != nullptr && i < mn_cnt_); - data_ptr += mul_power2(stride_); - smem_ptr += kSwizzleTileSize * kLoadVectorSize; - } - } - - CUTLASS_DEVICE - void load_to_smem_split(const int lane_id, void* smem, const int split_idx) { - // Here we rely on the fact that kThreads is 32, same as the swizzle pattern size - static_assert(kGmemLoadStrideM == kSwizzleM); - - const uint8_t* split_ptr = g_ptr_ + mul_power2(split_idx * stride_); - uint8_t* split_smem_ptr = reinterpret_cast(smem) + mul_power2(Swizzled64{}(lane_id)) + split_idx * kSwizzleTileSize * kLoadVectorSize; - - cutlass::arch::cp_async( - split_smem_ptr, split_ptr, g_ptr_ != nullptr && split_idx < mn_cnt_); - } - - /** - * @brief Advance global memory pointer to the next tile in the K dimension - */ - CUTLASS_DEVICE - SwizzleTileLoader& operator++() { - --k_cnt_; - if (k_cnt_ > 0) { - g_ptr_ += kLoadVectorSize * kSwizzleK; - } else { - g_ptr_ = nullptr; - } - return *this; - } - - /** - * @brief Load a ribbin of (SmemDimM, 32) from shared memory to fragment, - * fitting fp16 gemm sm80 tensor core shape, where k = 16 x sizeof(fp16) - */ - CUTLASS_DEVICE - void load_fragment_k32(const int lane_id, void const* smem, int offset_k, void* frag) { + public: + static constexpr int SmemDimM = SmemDimM_; + static constexpr int SmemDimK = 64; + static constexpr int kLoadVectorSize = 16; // one cp.async loads 16 bytes + static constexpr int kBlockSize = SmemDimM * SmemDimK; + static constexpr int kTiles = (SmemDimM / 8) * (SmemDimK / 16); + + // Swizzle pattern is 4x8 + static constexpr int kSwizzleK = SmemDimK / kLoadVectorSize; + static_assert(kSwizzleK == cute::_4::value); + static constexpr int kSwizzleM = cute::_8::value; + static constexpr int kSwizzleTileSize = kSwizzleK * kSwizzleM; + using Swizzled64 = decltype(cute::composition(cute::Swizzle<2, 0, 3>{}, + cute::Layout, + cute::Stride>{})); + + static constexpr int kThreads = 32; + static constexpr int kGmemLoadStrideM = kThreads / kSwizzleK; + static_assert(kGmemLoadStrideM * kSwizzleK == kThreads); + static_assert(SmemDimM % kGmemLoadStrideM == 0); + + // During pipelined MMA, each stage (processing a tile) is split + // into multiple mma iterations. We need to somehow split the + // the loading of global memory tile into multiple calls, doing + // our best to help spread these actions across different iterations + // in a stage. + static constexpr int kGloadSplit = SmemDimM / kGmemLoadStrideM; + + private: + /// Pointer to global memory to load data from + uint8_t const* g_ptr_{nullptr}; + /// Iteration boundaries in the M or N dimension + int mn_cnt_{0}; + /// Iteration boundaries in the K dimension, in strides of 16 + int k_cnt_{0}; + /// Stride in bytes to advance to next row in m or n dimension + const int stride_; + + public: + CUTLASS_DEVICE + SwizzleTileLoader( + void const* data_ptr, ///< Pointer to the global memory tiles + int byte_stride, ///< Stride in bytes to advance to next row + int mn_start, ///< Starting position in the M or N dimension + int mn_end, ///< End position in the M or N dimension + int k_start, ///< Starting position in the K dimension + int k_end, ///< End position in the K dimension + int lane_id) ///< ID of each participating thread + : stride_(byte_stride) { #ifndef NDEBUG - bool assert_fail = false; - if (offset_k != 0 && offset_k != 32) { - assert_fail = true; - if (lane_id == 0) { - printf("Invalid offset_k: %d!\n", offset_k); - } - } - if (SmemDimM % 16 != 0) { - // 2x2 tiles per load: 16 threads on the M dim and 2 on the K dim - // and don't want to deal with left over M - assert_fail = true; - if (lane_id == 0) { - printf("SmemDimM: %d two small, cannot use ldmatrix fully!\n", SmemDimM); - } - } - assert(assert_fail == false); + bool assertion_pass = true; + if (reinterpret_cast(data_ptr) % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("data_ptr: %p is not aligned to 16B boundary!\n", data_ptr); + } + } + if (byte_stride % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("byte_stride: %d is not aligned to 16B boundary!\n", byte_stride); + } + } + if (k_start % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("k_start: %d is not aligned to 16B boundary!\n", k_start); + } + } + if (k_end % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("k_end: %d is not aligned to 16B boundary!\n", k_end); + } + } + if (mn_end <= mn_start) { + assertion_pass = false; + if (lane_id == 0) { + printf("mn_end: %d is less than or equal to mn_start: %d!\n", mn_end, mn_start); + } + } + if (k_end <= k_start) { + assertion_pass = false; + if (lane_id == 0) { + printf("k_end: %d is less than or equal to k_start: %d!\n", k_end, k_start); + } + } + if (lane_id < 0 || lane_id >= kThreads) { + assertion_pass = false; + if (lane_id == 0) { + printf("Warp based loader, lane_id should be [0-32) but it is: %d!\n", lane_id); + } + } + assert(assertion_pass); #endif - constexpr int kStrideM = 16 / kSwizzleM; // Span 2 swizzle patterns on M dim - int m_lane_id = mod_power2<16>(lane_id); - int k_lane_id = (lane_id >> 4) + (offset_k >> 4); + int lane_m = div_power2(lane_id); + int lane_k = mod_power2(lane_id); + mn_start += lane_m; + k_start += mul_power2(lane_k); + + mn_cnt_ = div_up(mn_end - mn_start, kGmemLoadStrideM); + k_cnt_ = div_up(k_end - k_start, kSwizzleK * kLoadVectorSize); + if (mn_cnt_ <= 0 || k_cnt_ <= 0) { + mn_cnt_ = 0; + k_cnt_ = 0; + g_ptr_ = nullptr; + return; + } + g_ptr_ = reinterpret_cast(data_ptr) + mn_start * byte_stride + k_start; + // if (lane_id == 0) + // printf("lane_id: %d, mn_start: %d, mn_end: %d, k_start: %d, k_end: %d, g_ptr: %p\n", lane_id, mn_start, mn_end, k_start, k_end, g_ptr_); + } + + /** + * @brief Load a row major tile (SmemDimM, 64) from global memory to shared memory + */ + CUTLASS_DEVICE + void load_to_smem(const int lane_id, void* smem) { + // Here we rely on the fact that kThreads is 32, same as the swizzle pattern size + static_assert(kGmemLoadStrideM == kSwizzleM); + const uint8_t* data_ptr = g_ptr_; + uint8_t* smem_ptr = reinterpret_cast(smem) + mul_power2(Swizzled64{}(lane_id)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < SmemDimM / kSwizzleM; ++i) { + cutlass::arch::cp_async( + smem_ptr, data_ptr, g_ptr_ != nullptr && i < mn_cnt_); + data_ptr += mul_power2(stride_); + smem_ptr += kSwizzleTileSize * kLoadVectorSize; + } + } + + CUTLASS_DEVICE + void load_to_smem_split(const int lane_id, void* smem, const int split_idx) { + // Here we rely on the fact that kThreads is 32, same as the swizzle pattern size + static_assert(kGmemLoadStrideM == kSwizzleM); + + const uint8_t* split_ptr = g_ptr_ + mul_power2(split_idx * stride_); + uint8_t* split_smem_ptr = reinterpret_cast(smem) + mul_power2(Swizzled64{}(lane_id)) + split_idx * kSwizzleTileSize * kLoadVectorSize; + + cutlass::arch::cp_async( + split_smem_ptr, split_ptr, g_ptr_ != nullptr && split_idx < mn_cnt_); + } + + /** + * @brief Advance global memory pointer to the next tile in the K dimension + */ + CUTLASS_DEVICE + SwizzleTileLoader& operator++() { + --k_cnt_; + if (k_cnt_ > 0) { + g_ptr_ += kLoadVectorSize * kSwizzleK; + } else { + g_ptr_ = nullptr; + } + return *this; + } + + /** + * @brief Load a ribbin of (SmemDimM, 32) from shared memory to fragment, + * fitting fp16 gemm sm80 tensor core shape, where k = 16 x sizeof(fp16) + */ + CUTLASS_DEVICE + void load_fragment_k32(const int lane_id, void const* smem, int offset_k, void* frag) { +#ifndef NDEBUG + bool assert_fail = false; + if (offset_k != 0 && offset_k != 32) { + assert_fail = true; + if (lane_id == 0) { + printf("Invalid offset_k: %d!\n", offset_k); + } + } + if (SmemDimM % 16 != 0) { + // 2x2 tiles per load: 16 threads on the M dim and 2 on the K dim + // and don't want to deal with left over M + assert_fail = true; + if (lane_id == 0) { + printf("SmemDimM: %d two small, cannot use ldmatrix fully!\n", SmemDimM); + } + } + assert(assert_fail == false); +#endif - int m_tile_id = div_power2(m_lane_id); - int m_tile_offset = mod_power2(m_lane_id); - int swizzled_id = Swizzled64{}(k_lane_id, m_tile_offset) + mul_power2(m_tile_id); - // printf("lane_id: %d, m_lane_id: %d, k_lane_id: %d, swizzled_id: %d\n", lane_id, m_lane_id, k_lane_id, swizzled_id); - const uint8_t* smem_ptr = reinterpret_cast(smem) + mul_power2(swizzled_id); + constexpr int kStrideM = 16 / kSwizzleM; // Span 2 swizzle patterns on M dim + int m_lane_id = mod_power2<16>(lane_id); + int k_lane_id = (lane_id >> 4) + (offset_k >> 4); - using FragType = cutlass::Array; - FragType* frag_ptr = reinterpret_cast(frag); + int m_tile_id = div_power2(m_lane_id); + int m_tile_offset = mod_power2(m_lane_id); + int swizzled_id = Swizzled64{}(k_lane_id, m_tile_offset) + mul_power2(m_tile_id); + // printf("lane_id: %d, m_lane_id: %d, k_lane_id: %d, swizzled_id: %d\n", lane_id, m_lane_id, k_lane_id, swizzled_id); + const uint8_t* smem_ptr = reinterpret_cast(smem) + mul_power2(swizzled_id); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < (SmemDimM / 16); ++i) { - // printf("lane_id: %d, load %d, val: %d, smem_ptr: %p\n", lane_id, i, smem_ptr[0], smem_ptr); - cutlass::arch::ldsm(frag_ptr[i], smem_ptr); - smem_ptr += kSwizzleTileSize * kStrideM * kLoadVectorSize; - } - } + using FragType = cutlass::Array; + FragType* frag_ptr = reinterpret_cast(frag); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < (SmemDimM / 16); ++i) { + // printf("lane_id: %d, load %d, val: %d, smem_ptr: %p\n", lane_id, i, smem_ptr[0], smem_ptr); + cutlass::arch::ldsm(frag_ptr[i], smem_ptr); + smem_ptr += kSwizzleTileSize * kStrideM * kLoadVectorSize; + } + } - CUTLASS_DEVICE - void load_fragment_k64(const int lane_id, void const* smem, int offset_k, void* frag) { + CUTLASS_DEVICE + void load_fragment_k64(const int lane_id, void const* smem, int offset_k, void* frag) { #ifndef NDEBUG - // Here we use a single warp to load 4 tiles on the k dimension. - // This is only useful in loading packed B tensor where a 2x2 int4 - // tile structure is disguised as a single fp16 tile. So 4 such - // tiles, when dequantized, become 4 set of 2x2 fp16 tiles. Each - // of the 2x2 fp16 tiles can participate in two 16x8x16 tensor core - // operations. - bool assert_fail = false; - if (SmemDimM != 8) { - assert_fail = true; - if (lane_id == 0) { - printf("Special case for SmemDimM = 8 but found %d!\n", SmemDimM); - } - } - if (offset_k != 0) { - assert_fail = true; - if (lane_id == 0) { - printf("Special case for offset_k = 0 but found %d!\n", offset_k); - } - } - assert(assert_fail == false); + // Here we use a single warp to load 4 tiles on the k dimension. + // This is only useful in loading packed B tensor where a 2x2 int4 + // tile structure is disguised as a single fp16 tile. So 4 such + // tiles, when dequantized, become 4 set of 2x2 fp16 tiles. Each + // of the 2x2 fp16 tiles can participate in two 16x8x16 tensor core + // operations. + bool assert_fail = false; + if (SmemDimM != 8) { + assert_fail = true; + if (lane_id == 0) { + printf("Special case for SmemDimM = 8 but found %d!\n", SmemDimM); + } + } + if (offset_k != 0) { + assert_fail = true; + if (lane_id == 0) { + printf("Special case for offset_k = 0 but found %d!\n", offset_k); + } + } + assert(assert_fail == false); #endif - // 1x4 tiles per load: 8 threads on the M dim and 4 on the K dim - int m_lane_id = mod_power2<8>(lane_id); - int k_lane_id = div_power2<8>(lane_id); + // 1x4 tiles per load: 8 threads on the M dim and 4 on the K dim + int m_lane_id = mod_power2<8>(lane_id); + int k_lane_id = div_power2<8>(lane_id); - int swizzled_id = Swizzled64{}(k_lane_id, m_lane_id); - // printf("lane_id: %d, m_lane_id: %d, k_lane_id: %d, swizzled_id: %d\n", lane_id, m_lane_id, k_lane_id, swizzled_id); - const uint8_t* smem_ptr = reinterpret_cast(smem) + mul_power2(swizzled_id); - - using FragType = cutlass::Array; - FragType* frag_ptr = reinterpret_cast(frag); - cutlass::arch::ldsm(frag_ptr[0], smem_ptr); - } + int swizzled_id = Swizzled64{}(k_lane_id, m_lane_id); + // printf("lane_id: %d, m_lane_id: %d, k_lane_id: %d, swizzled_id: %d\n", lane_id, m_lane_id, k_lane_id, swizzled_id); + const uint8_t* smem_ptr = reinterpret_cast(smem) + mul_power2(swizzled_id); + using FragType = cutlass::Array; + FragType* frag_ptr = reinterpret_cast(frag); + cutlass::arch::ldsm(frag_ptr[0], smem_ptr); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template class SwizzleTileLoader { - public: - static constexpr int SmemDimM = SmemDimM_; - static constexpr int SmemDimK = 128; - static constexpr int kLoadVectorSize = 16; // one cp.async loads 16 bytes - static constexpr int kBlockSize = SmemDimM * SmemDimK; - static constexpr int kTiles = (SmemDimM / 8) * (SmemDimK / 16); - - // Swizzle pattern is 8x8 - static constexpr int kSwizzleK = SmemDimK / kLoadVectorSize; - static_assert(kSwizzleK == cute::_8::value); - static constexpr int kSwizzleM = cute::_8::value; - static constexpr int kSwizzleTileSize = kSwizzleK * kSwizzleM; - using Swizzled128 = decltype( - cute::composition(cute::Swizzle<3,0,3>{}, - cute::Layout, - cute::Stride>{})); - - static constexpr int kThreads = 32; - static constexpr int kGmemLoadStrideM = kThreads / kSwizzleK; - static_assert(kGmemLoadStrideM * kSwizzleK == kThreads); - - // During pipelined MMA, each stage (processing a tile) is split - // into multiple mma iterations. We need to somehow split the - // the loading of global memory tile into multiple calls, doing - // our best to help spread these actions across different iterations - // in a stage. - static constexpr int kGloadSplit = SmemDimM / kGmemLoadStrideM; + public: + static constexpr int SmemDimM = SmemDimM_; + static constexpr int SmemDimK = 128; + static constexpr int kLoadVectorSize = 16; // one cp.async loads 16 bytes + static constexpr int kBlockSize = SmemDimM * SmemDimK; + static constexpr int kTiles = (SmemDimM / 8) * (SmemDimK / 16); + + // Swizzle pattern is 8x8 + static constexpr int kSwizzleK = SmemDimK / kLoadVectorSize; + static_assert(kSwizzleK == cute::_8::value); + static constexpr int kSwizzleM = cute::_8::value; + static constexpr int kSwizzleTileSize = kSwizzleK * kSwizzleM; + using Swizzled128 = decltype(cute::composition(cute::Swizzle<3, 0, 3>{}, + cute::Layout, + cute::Stride>{})); + + static constexpr int kThreads = 32; + static constexpr int kGmemLoadStrideM = kThreads / kSwizzleK; + static_assert(kGmemLoadStrideM * kSwizzleK == kThreads); + + // During pipelined MMA, each stage (processing a tile) is split + // into multiple mma iterations. We need to somehow split the + // the loading of global memory tile into multiple calls, doing + // our best to help spread these actions across different iterations + // in a stage. + static constexpr int kGloadSplit = SmemDimM / kGmemLoadStrideM; private: - /// Pointer to global memory to load data from - uint8_t const* g_ptr_{nullptr}; - /// Iteration boundaries in the M or N dimension - int mn_cnt_{0}; - /// Iteration boundaries in the K dimension, in strides of 16 - int k_cnt_{0}; - /// Stride in bytes to advance to next row in m or n dimension - const int stride_; + /// Pointer to global memory to load data from + uint8_t const* g_ptr_{nullptr}; + /// Iteration boundaries in the M or N dimension + int mn_cnt_{0}; + /// Iteration boundaries in the K dimension, in strides of 16 + int k_cnt_{0}; + /// Stride in bytes to advance to next row in m or n dimension + const int stride_; public: - CUTLASS_DEVICE - SwizzleTileLoader( - void const* data_ptr, ///< Pointer to the global memory tiles - int byte_stride, ///< Stride in bytes to advance to next row - int mn_start, ///< Starting position in the M or N dimension - int mn_end, ///< End position in the M or N dimension - int k_start, ///< Starting position in the K dimension - int k_end, ///< End position in the K dimension - int lane_id) ///< ID of each participating thread - : stride_(byte_stride) { - #ifndef NDEBUG - bool assertion_pass = true; - if (reinterpret_cast(data_ptr) % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("data_ptr: %p is not aligned to 16B boundary!\n", data_ptr); - } - } - if (byte_stride % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("byte_stride: %d is not aligned to 16B boundary!\n", byte_stride); - } - } - if (k_start % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("k_start: %d is not aligned to 16B boundary!\n", k_start); - } - } - if (k_end % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("k_end: %d is not aligned to 16B boundary!\n", k_end); - } - } - if (mn_end <= mn_start) { - assertion_pass = false; - if (lane_id == 0) { - printf("mn_end: %d is less than or equal to mn_start: %d!\n", mn_end, mn_start); - } - } - if (k_end <= k_start) { - assertion_pass = false; - if (lane_id == 0) { - printf("k_end: %d is less than or equal to k_start: %d!\n", k_end, k_start); - } - } - if (lane_id < 0 || lane_id >= kThreads) { - assertion_pass = false; - if (lane_id == 0) { - printf("Warp based loader, lane_id should be [0-32) but it is: %d!\n", lane_id); - } - } - assert(assertion_pass); - #endif - - int lane_m = lane_id / kSwizzleK; - int lane_k = lane_id % kSwizzleK; - mn_start += lane_m; - k_start += lane_k * kLoadVectorSize; - - mn_cnt_ = div_up(mn_end - mn_start, kGmemLoadStrideM); - k_cnt_ = div_up(k_end - k_start, kSwizzleK * kLoadVectorSize); - if (mn_cnt_ <= 0 || k_cnt_ <= 0) { - mn_cnt_ = 0; - k_cnt_ = 0; - g_ptr_ = nullptr; - return; - } - g_ptr_ = reinterpret_cast(data_ptr) + mn_start * byte_stride + k_start; - // if (lane_id == 0) - // printf("lane_id: %d, mn_start: %d, mn_end: %d, k_start: %d, k_end: %d, g_ptr: %p\n", lane_id, mn_start, mn_end, k_start, k_end, g_ptr_); - } - - /** - * @brief Load a row major tile (SmemDimM, 128) from global memory to shared memory - */ - CUTLASS_DEVICE - void load_to_smem(const int lane_id, void* smem) { - const uint8_t* data_ptr = g_ptr_; - - // The swizzle pattern is 8x8, but we only have 32 threads, - // covering half of the swizzle pattern - static_assert(kGmemLoadStrideM * 2 == kSwizzleM); - uint8_t* smem_ptr0 = reinterpret_cast(smem) + Swizzled128{}(lane_id) * kLoadVectorSize; - uint8_t* smem_ptr1 = reinterpret_cast(smem) + Swizzled128{}(lane_id + kThreads) * kLoadVectorSize; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < SmemDimM / kGmemLoadStrideM;) { - cutlass::arch::cp_async( - smem_ptr0, data_ptr, g_ptr_ != nullptr && i < mn_cnt_); - data_ptr += stride_ * kGmemLoadStrideM; - smem_ptr0 += kSwizzleTileSize * kLoadVectorSize; - ++i; - - cutlass::arch::cp_async( - smem_ptr1, data_ptr, g_ptr_ != nullptr && i < mn_cnt_); - data_ptr += stride_ * kGmemLoadStrideM; - smem_ptr1 += kSwizzleTileSize * kLoadVectorSize; - ++i; - } - } - - CUTLASS_DEVICE - void load_to_smem_split(const int lane_id, void* smem, const int split_idx){ - const uint8_t* split_ptr = g_ptr_ + split_idx * stride_ * kGmemLoadStrideM; - const int offset = (split_idx >> 1) * kSwizzleTileSize * kLoadVectorSize; - const int swizzled = Swizzled128{}(lane_id + (split_idx & 1) * kThreads) * kLoadVectorSize; - uint8_t* split_smem_ptr = reinterpret_cast(smem) + swizzled + offset; - - cutlass::arch::cp_async( - split_smem_ptr, split_ptr, g_ptr_ != nullptr && split_idx < mn_cnt_); - } - - /** - * @brief Advance global memory pointer to the next tile in the K dimension - */ - CUTLASS_DEVICE - SwizzleTileLoader& operator++() { - --k_cnt_; - if (k_cnt_ > 0) { - g_ptr_ += kLoadVectorSize * kSwizzleK; - } else { - g_ptr_ = nullptr; - } - return *this; - } - - /** - * @brief Load a ribbin of (SmemDimM, 32) from shared memory to fragment, - * fitting fp16 gemm sm80 tensor core shape, where k = 16 x sizeof(fp16) - */ - CUTLASS_DEVICE - void load_fragment_k32(const int lane_id, void const* smem, int offset_k, void* frag) { + CUTLASS_DEVICE + SwizzleTileLoader( + void const* data_ptr, ///< Pointer to the global memory tiles + int byte_stride, ///< Stride in bytes to advance to next row + int mn_start, ///< Starting position in the M or N dimension + int mn_end, ///< End position in the M or N dimension + int k_start, ///< Starting position in the K dimension + int k_end, ///< End position in the K dimension + int lane_id) ///< ID of each participating thread + : stride_(byte_stride) { #ifndef NDEBUG - bool assert_fail = false; - if ((offset_k % 32) != 0) { - assert_fail = true; - if (lane_id == 0) { - printf("Invalid offset_k: %d!\n", offset_k); - } - } - if ((SmemDimM % 16) != 0) { - // 2x2 tiles per load: 16 threads on the M dim and 2 on the K dim - // and don't want to deal with left over M - assert_fail = true; - if (lane_id == 0) { - printf("SmemDimM: %d two small, cannot use ldmatrix fully!\n", SmemDimM); - } - } - assert(assert_fail == false); + bool assertion_pass = true; + if (reinterpret_cast(data_ptr) % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("data_ptr: %p is not aligned to 16B boundary!\n", data_ptr); + } + } + if (byte_stride % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("byte_stride: %d is not aligned to 16B boundary!\n", byte_stride); + } + } + if (k_start % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("k_start: %d is not aligned to 16B boundary!\n", k_start); + } + } + if (k_end % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("k_end: %d is not aligned to 16B boundary!\n", k_end); + } + } + if (mn_end <= mn_start) { + assertion_pass = false; + if (lane_id == 0) { + printf("mn_end: %d is less than or equal to mn_start: %d!\n", mn_end, mn_start); + } + } + if (k_end <= k_start) { + assertion_pass = false; + if (lane_id == 0) { + printf("k_end: %d is less than or equal to k_start: %d!\n", k_end, k_start); + } + } + if (lane_id < 0 || lane_id >= kThreads) { + assertion_pass = false; + if (lane_id == 0) { + printf("Warp based loader, lane_id should be [0-32) but it is: %d!\n", lane_id); + } + } + assert(assertion_pass); #endif - constexpr int kStrideM = 16 / kSwizzleM; // Span 2 swizzle patterns on M dim - int m_lane_id = lane_id % 16; - int k_lane_id = lane_id / 16 + offset_k / 16; + int lane_m = lane_id / kSwizzleK; + int lane_k = lane_id % kSwizzleK; + mn_start += lane_m; + k_start += lane_k * kLoadVectorSize; + + mn_cnt_ = div_up(mn_end - mn_start, kGmemLoadStrideM); + k_cnt_ = div_up(k_end - k_start, kSwizzleK * kLoadVectorSize); + if (mn_cnt_ <= 0 || k_cnt_ <= 0) { + mn_cnt_ = 0; + k_cnt_ = 0; + g_ptr_ = nullptr; + return; + } + g_ptr_ = reinterpret_cast(data_ptr) + mn_start * byte_stride + k_start; + // if (lane_id == 0) + // printf("lane_id: %d, mn_start: %d, mn_end: %d, k_start: %d, k_end: %d, g_ptr: %p\n", lane_id, mn_start, mn_end, k_start, k_end, g_ptr_); + } + + /** + * @brief Load a row major tile (SmemDimM, 128) from global memory to shared memory + */ + CUTLASS_DEVICE + void load_to_smem(const int lane_id, void* smem) { + const uint8_t* data_ptr = g_ptr_; + + // The swizzle pattern is 8x8, but we only have 32 threads, + // covering half of the swizzle pattern + static_assert(kGmemLoadStrideM * 2 == kSwizzleM); + uint8_t* smem_ptr0 = reinterpret_cast(smem) + Swizzled128{}(lane_id)*kLoadVectorSize; + uint8_t* smem_ptr1 = reinterpret_cast(smem) + Swizzled128{}(lane_id + kThreads) * kLoadVectorSize; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < SmemDimM / kGmemLoadStrideM;) { + cutlass::arch::cp_async( + smem_ptr0, data_ptr, g_ptr_ != nullptr && i < mn_cnt_); + data_ptr += stride_ * kGmemLoadStrideM; + smem_ptr0 += kSwizzleTileSize * kLoadVectorSize; + ++i; + + cutlass::arch::cp_async( + smem_ptr1, data_ptr, g_ptr_ != nullptr && i < mn_cnt_); + data_ptr += stride_ * kGmemLoadStrideM; + smem_ptr1 += kSwizzleTileSize * kLoadVectorSize; + ++i; + } + } + + CUTLASS_DEVICE + void load_to_smem_split(const int lane_id, void* smem, const int split_idx) { + const uint8_t* split_ptr = g_ptr_ + split_idx * stride_ * kGmemLoadStrideM; + const int offset = (split_idx >> 1) * kSwizzleTileSize * kLoadVectorSize; + const int swizzled = Swizzled128{}(lane_id + (split_idx & 1) * kThreads) * kLoadVectorSize; + uint8_t* split_smem_ptr = reinterpret_cast(smem) + swizzled + offset; + + cutlass::arch::cp_async( + split_smem_ptr, split_ptr, g_ptr_ != nullptr && split_idx < mn_cnt_); + } + + /** + * @brief Advance global memory pointer to the next tile in the K dimension + */ + CUTLASS_DEVICE + SwizzleTileLoader& operator++() { + --k_cnt_; + if (k_cnt_ > 0) { + g_ptr_ += kLoadVectorSize * kSwizzleK; + } else { + g_ptr_ = nullptr; + } + return *this; + } + + /** + * @brief Load a ribbin of (SmemDimM, 32) from shared memory to fragment, + * fitting fp16 gemm sm80 tensor core shape, where k = 16 x sizeof(fp16) + */ + CUTLASS_DEVICE + void load_fragment_k32(const int lane_id, void const* smem, int offset_k, void* frag) { +#ifndef NDEBUG + bool assert_fail = false; + if ((offset_k % 32) != 0) { + assert_fail = true; + if (lane_id == 0) { + printf("Invalid offset_k: %d!\n", offset_k); + } + } + if ((SmemDimM % 16) != 0) { + // 2x2 tiles per load: 16 threads on the M dim and 2 on the K dim + // and don't want to deal with left over M + assert_fail = true; + if (lane_id == 0) { + printf("SmemDimM: %d two small, cannot use ldmatrix fully!\n", SmemDimM); + } + } + assert(assert_fail == false); +#endif + + constexpr int kStrideM = 16 / kSwizzleM; // Span 2 swizzle patterns on M dim + int m_lane_id = lane_id % 16; + int k_lane_id = lane_id / 16 + offset_k / 16; - int m_tile_id = m_lane_id / kSwizzleM; - int m_tile_offset = m_lane_id % kSwizzleM; - int swizzled_id = Swizzled128{}(k_lane_id, m_tile_offset) + m_tile_id * kSwizzleTileSize; - // printf("lane_id: %d, m_lane_id: %d, k_lane_id: %d, swizzled_id: %d\n", lane_id, m_lane_id, k_lane_id, swizzled_id); - const uint8_t* smem_ptr = reinterpret_cast(smem) + swizzled_id * kLoadVectorSize; + int m_tile_id = m_lane_id / kSwizzleM; + int m_tile_offset = m_lane_id % kSwizzleM; + int swizzled_id = Swizzled128{}(k_lane_id, m_tile_offset) + m_tile_id * kSwizzleTileSize; + // printf("lane_id: %d, m_lane_id: %d, k_lane_id: %d, swizzled_id: %d\n", lane_id, m_lane_id, k_lane_id, swizzled_id); + const uint8_t* smem_ptr = reinterpret_cast(smem) + swizzled_id * kLoadVectorSize; - using FragType = cutlass::Array; - FragType* frag_ptr = reinterpret_cast(frag); + using FragType = cutlass::Array; + FragType* frag_ptr = reinterpret_cast(frag); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < SmemDimM / 16; ++i) { - // printf("lane_id: %d, load %d, val: %d, smem_ptr: %p\n", lane_id, i, smem_ptr[0], smem_ptr); - cutlass::arch::ldsm(frag_ptr[i], smem_ptr); - smem_ptr += kSwizzleTileSize * kStrideM * kLoadVectorSize; - } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < SmemDimM / 16; ++i) { + // printf("lane_id: %d, load %d, val: %d, smem_ptr: %p\n", lane_id, i, smem_ptr[0], smem_ptr); + cutlass::arch::ldsm(frag_ptr[i], smem_ptr); + smem_ptr += kSwizzleTileSize * kStrideM * kLoadVectorSize; } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template class SwizzleTileLoader { - public: - static constexpr int SmemDimM = SmemDimM_; - static constexpr int SmemDimK = 32; - static constexpr int kLoadVectorSize = 16; // one cp.async loads 16 bytes - static constexpr int kBlockSize = SmemDimM * SmemDimK; - static constexpr int kTiles = (SmemDimM / 8) * (SmemDimK / 16); - - // Swizzle pattern is 2x16 - static constexpr int kSwizzleK = SmemDimK / kLoadVectorSize; - static_assert(kSwizzleK == cute::_2::value); - static constexpr int kSwizzleM = cute::_16::value; - static constexpr int kSwizzleTileSize = kSwizzleK * kSwizzleM; - using Swizzled32 = decltype( - cute::composition(cute::Swizzle<1,0,3>{}, - cute::Layout, - cute::Stride>{})); - - static constexpr int kThreads = 32; - static constexpr int kGmemLoadStrideM = kThreads / kSwizzleK; - static_assert(kGmemLoadStrideM * kSwizzleK == kThreads); - - // During pipelined MMA, each stage (processing a tile) is split - // into multiple mma iterations. We need to somehow split the - // the loading of global memory tile into multiple calls, doing - // our best to help spread these actions across different iterations - // in a stage. - static constexpr int kGloadSplit = SmemDimM / kGmemLoadStrideM; - - private: - /// Pointer to global memory to load data from - uint8_t const* g_ptr_{nullptr}; - /// Iteration boundaries in the M or N dimension - int mn_cnt_{0}; - /// Iteration boundaries in the K dimension, in strides of 16 - int k_cnt_{0}; - /// Stride in bytes to advance to next row in m or n dimension - const int stride_; - - public: - CUTLASS_DEVICE - SwizzleTileLoader( - void const* data_ptr, ///< Pointer to the global memory tiles - int byte_stride, ///< Stride in bytes to advance to next row - int mn_start, ///< Starting position in the M or N dimension - int mn_end, ///< End position in the M or N dimension - int k_start, ///< Starting position in the K dimension - int k_end, ///< End position in the K dimension - int lane_id) ///< ID of each participating thread - : stride_(byte_stride) { - #ifndef NDEBUG - bool assertion_pass = true; - if (reinterpret_cast(data_ptr) % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("data_ptr: %p is not aligned to 16B boundary!\n", data_ptr); - } - } - if (byte_stride % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("byte_stride: %d is not aligned to 16B boundary!\n", byte_stride); - } - } - if (k_start % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("k_start: %d is not aligned to 16B boundary!\n", k_start); - } - } - if (k_end % kLoadVectorSize != 0) { - assertion_pass = false; - if (lane_id == 0) { - printf("k_end: %d is not aligned to 16B boundary!\n", k_end); - } - } - if (mn_end <= mn_start) { - assertion_pass = false; - if (lane_id == 0) { - printf("mn_end: %d is less than or equal to mn_start: %d!\n", mn_end, mn_start); - } - } - if (k_end <= k_start) { - assertion_pass = false; - if (lane_id == 0) { - printf("k_end: %d is less than or equal to k_start: %d!\n", k_end, k_start); - } - } - if (lane_id < 0 || lane_id >= kThreads) { - assertion_pass = false; - if (lane_id == 0) { - printf("Warp based loader, lane_id should be [0-32) but it is: %d!\n", lane_id); - } - } - assert(assertion_pass); - #endif - - int lane_m = lane_id / kSwizzleK; - int lane_k = lane_id % kSwizzleK; - mn_start += lane_m; - k_start += lane_k * kLoadVectorSize; - - mn_cnt_ = div_up(mn_end - mn_start, kGmemLoadStrideM); - k_cnt_ = div_up(k_end - k_start, kSwizzleK * kLoadVectorSize); - if (mn_cnt_ <= 0 || k_cnt_ <= 0) { - mn_cnt_ = 0; - k_cnt_ = 0; - g_ptr_ = nullptr; - return; - } - g_ptr_ = reinterpret_cast(data_ptr) + mn_start * byte_stride + k_start; - // if (lane_id == 0) - // printf("lane_id: %d, mn_start: %d, mn_end: %d, k_start: %d, k_end: %d, g_ptr: %p\n", lane_id, mn_start, mn_end, k_start, k_end, g_ptr_); - } - - /** - * @brief Load a row major tile (SmemDimM, 32) from global memory to shared memory - */ - CUTLASS_DEVICE - void load_to_smem(const int lane_id, void* smem) { - // The swizzle pattern is 2x16, same size as kThreads - static_assert(kGmemLoadStrideM == kSwizzleM); - const uint8_t* data_ptr = g_ptr_; - uint8_t* smem_ptr = reinterpret_cast(smem) + Swizzled32{}(lane_id) * kLoadVectorSize; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < SmemDimM / kSwizzleM; ++i) { - cutlass::arch::cp_async( - smem_ptr, data_ptr, g_ptr_ != nullptr && i < mn_cnt_); - data_ptr += stride_ * kGmemLoadStrideM; - smem_ptr += kSwizzleTileSize * kLoadVectorSize; - } - } - - CUTLASS_DEVICE - void load_to_smem_split(const int lane_id, void* smem, const int split_idx){ - // Here we rely on the fact that kThreads is 32, same as the swizzle pattern size - static_assert(kGmemLoadStrideM == kSwizzleM); - - const uint8_t* split_ptr = g_ptr_ + split_idx * stride_ * kGmemLoadStrideM; - uint8_t* split_smem_ptr = reinterpret_cast(smem) + Swizzled32{}(lane_id) * kLoadVectorSize + split_idx * kSwizzleTileSize * kLoadVectorSize; - - cutlass::arch::cp_async( - split_smem_ptr, split_ptr, g_ptr_ != nullptr && split_idx < mn_cnt_); - } - - /** - * @brief Advance global memory pointer to the next tile in the K dimension - */ - CUTLASS_DEVICE - SwizzleTileLoader& operator++() { - --k_cnt_; - if (k_cnt_ > 0) { - g_ptr_ += kLoadVectorSize * kSwizzleK; - } else { - g_ptr_ = nullptr; - } - return *this; - } - - /** - * @brief Load a ribbin of (SmemDimM, 32) from shared memory to fragment, - * fitting fp16 gemm sm80 tensor core shape, where k = 16 x sizeof(fp16) - */ - CUTLASS_DEVICE - void load_fragment_k32(const int lane_id, void const* smem, int offset_k, void* frag) { + public: + static constexpr int SmemDimM = SmemDimM_; + static constexpr int SmemDimK = 32; + static constexpr int kLoadVectorSize = 16; // one cp.async loads 16 bytes + static constexpr int kBlockSize = SmemDimM * SmemDimK; + static constexpr int kTiles = (SmemDimM / 8) * (SmemDimK / 16); + + // Swizzle pattern is 2x16 + static constexpr int kSwizzleK = SmemDimK / kLoadVectorSize; + static_assert(kSwizzleK == cute::_2::value); + static constexpr int kSwizzleM = cute::_16::value; + static constexpr int kSwizzleTileSize = kSwizzleK * kSwizzleM; + using Swizzled32 = decltype(cute::composition(cute::Swizzle<1, 0, 3>{}, + cute::Layout, + cute::Stride>{})); + + static constexpr int kThreads = 32; + static constexpr int kGmemLoadStrideM = kThreads / kSwizzleK; + static_assert(kGmemLoadStrideM * kSwizzleK == kThreads); + + // During pipelined MMA, each stage (processing a tile) is split + // into multiple mma iterations. We need to somehow split the + // the loading of global memory tile into multiple calls, doing + // our best to help spread these actions across different iterations + // in a stage. + static constexpr int kGloadSplit = SmemDimM / kGmemLoadStrideM; + + private: + /// Pointer to global memory to load data from + uint8_t const* g_ptr_{nullptr}; + /// Iteration boundaries in the M or N dimension + int mn_cnt_{0}; + /// Iteration boundaries in the K dimension, in strides of 16 + int k_cnt_{0}; + /// Stride in bytes to advance to next row in m or n dimension + const int stride_; + + public: + CUTLASS_DEVICE + SwizzleTileLoader( + void const* data_ptr, ///< Pointer to the global memory tiles + int byte_stride, ///< Stride in bytes to advance to next row + int mn_start, ///< Starting position in the M or N dimension + int mn_end, ///< End position in the M or N dimension + int k_start, ///< Starting position in the K dimension + int k_end, ///< End position in the K dimension + int lane_id) ///< ID of each participating thread + : stride_(byte_stride) { #ifndef NDEBUG - bool assert_fail = false; - if (offset_k != 0) { - assert_fail = true; - if (lane_id == 0) { - printf("Invalid offset_k: %d!\n", offset_k); - } - } - if ((SmemDimM % 16) != 0) { - // 2x2 tiles per load: 16 threads on the M dim and 2 on the K dim - // and don't want to deal with left over M - assert_fail = true; - if (lane_id == 0) { - printf("SmemDimM: %d two small, cannot use ldmatrix fully!\n", SmemDimM); - } - } - assert(assert_fail == false); + bool assertion_pass = true; + if (reinterpret_cast(data_ptr) % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("data_ptr: %p is not aligned to 16B boundary!\n", data_ptr); + } + } + if (byte_stride % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("byte_stride: %d is not aligned to 16B boundary!\n", byte_stride); + } + } + if (k_start % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("k_start: %d is not aligned to 16B boundary!\n", k_start); + } + } + if (k_end % kLoadVectorSize != 0) { + assertion_pass = false; + if (lane_id == 0) { + printf("k_end: %d is not aligned to 16B boundary!\n", k_end); + } + } + if (mn_end <= mn_start) { + assertion_pass = false; + if (lane_id == 0) { + printf("mn_end: %d is less than or equal to mn_start: %d!\n", mn_end, mn_start); + } + } + if (k_end <= k_start) { + assertion_pass = false; + if (lane_id == 0) { + printf("k_end: %d is less than or equal to k_start: %d!\n", k_end, k_start); + } + } + if (lane_id < 0 || lane_id >= kThreads) { + assertion_pass = false; + if (lane_id == 0) { + printf("Warp based loader, lane_id should be [0-32) but it is: %d!\n", lane_id); + } + } + assert(assertion_pass); +#endif + + int lane_m = lane_id / kSwizzleK; + int lane_k = lane_id % kSwizzleK; + mn_start += lane_m; + k_start += lane_k * kLoadVectorSize; + + mn_cnt_ = div_up(mn_end - mn_start, kGmemLoadStrideM); + k_cnt_ = div_up(k_end - k_start, kSwizzleK * kLoadVectorSize); + if (mn_cnt_ <= 0 || k_cnt_ <= 0) { + mn_cnt_ = 0; + k_cnt_ = 0; + g_ptr_ = nullptr; + return; + } + g_ptr_ = reinterpret_cast(data_ptr) + mn_start * byte_stride + k_start; + // if (lane_id == 0) + // printf("lane_id: %d, mn_start: %d, mn_end: %d, k_start: %d, k_end: %d, g_ptr: %p\n", lane_id, mn_start, mn_end, k_start, k_end, g_ptr_); + } + + /** + * @brief Load a row major tile (SmemDimM, 32) from global memory to shared memory + */ + CUTLASS_DEVICE + void load_to_smem(const int lane_id, void* smem) { + // The swizzle pattern is 2x16, same size as kThreads + static_assert(kGmemLoadStrideM == kSwizzleM); + const uint8_t* data_ptr = g_ptr_; + uint8_t* smem_ptr = reinterpret_cast(smem) + Swizzled32{}(lane_id)*kLoadVectorSize; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < SmemDimM / kSwizzleM; ++i) { + cutlass::arch::cp_async( + smem_ptr, data_ptr, g_ptr_ != nullptr && i < mn_cnt_); + data_ptr += stride_ * kGmemLoadStrideM; + smem_ptr += kSwizzleTileSize * kLoadVectorSize; + } + } + + CUTLASS_DEVICE + void load_to_smem_split(const int lane_id, void* smem, const int split_idx) { + // Here we rely on the fact that kThreads is 32, same as the swizzle pattern size + static_assert(kGmemLoadStrideM == kSwizzleM); + + const uint8_t* split_ptr = g_ptr_ + split_idx * stride_ * kGmemLoadStrideM; + uint8_t* split_smem_ptr = reinterpret_cast(smem) + Swizzled32{}(lane_id)*kLoadVectorSize + split_idx * kSwizzleTileSize * kLoadVectorSize; + + cutlass::arch::cp_async( + split_smem_ptr, split_ptr, g_ptr_ != nullptr && split_idx < mn_cnt_); + } + + /** + * @brief Advance global memory pointer to the next tile in the K dimension + */ + CUTLASS_DEVICE + SwizzleTileLoader& operator++() { + --k_cnt_; + if (k_cnt_ > 0) { + g_ptr_ += kLoadVectorSize * kSwizzleK; + } else { + g_ptr_ = nullptr; + } + return *this; + } + + /** + * @brief Load a ribbin of (SmemDimM, 32) from shared memory to fragment, + * fitting fp16 gemm sm80 tensor core shape, where k = 16 x sizeof(fp16) + */ + CUTLASS_DEVICE + void load_fragment_k32(const int lane_id, void const* smem, int offset_k, void* frag) { +#ifndef NDEBUG + bool assert_fail = false; + if (offset_k != 0) { + assert_fail = true; + if (lane_id == 0) { + printf("Invalid offset_k: %d!\n", offset_k); + } + } + if ((SmemDimM % 16) != 0) { + // 2x2 tiles per load: 16 threads on the M dim and 2 on the K dim + // and don't want to deal with left over M + assert_fail = true; + if (lane_id == 0) { + printf("SmemDimM: %d two small, cannot use ldmatrix fully!\n", SmemDimM); + } + } + assert(assert_fail == false); #endif - constexpr int kStrideM = 16 / kSwizzleM; // Span 1 swizzle patterns on M dim - int m_lane_id = lane_id % 16; - int k_lane_id = lane_id / 16; // 0 or 1 + constexpr int kStrideM = 16 / kSwizzleM; // Span 1 swizzle patterns on M dim + int m_lane_id = lane_id % 16; + int k_lane_id = lane_id / 16; // 0 or 1 - int swizzled_id = Swizzled32{}(k_lane_id, m_lane_id); - // printf("lane_id: %d, m_lane_id: %d, k_lane_id: %d, swizzled_id: %d\n", lane_id, m_lane_id, k_lane_id, swizzled_id); - const uint8_t* smem_ptr = reinterpret_cast(smem) + swizzled_id * kLoadVectorSize; + int swizzled_id = Swizzled32{}(k_lane_id, m_lane_id); + // printf("lane_id: %d, m_lane_id: %d, k_lane_id: %d, swizzled_id: %d\n", lane_id, m_lane_id, k_lane_id, swizzled_id); + const uint8_t* smem_ptr = reinterpret_cast(smem) + swizzled_id * kLoadVectorSize; - using FragType = cutlass::Array; - FragType* frag_ptr = reinterpret_cast(frag); + using FragType = cutlass::Array; + FragType* frag_ptr = reinterpret_cast(frag); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < SmemDimM / 16; ++i) { - // printf("lane_id: %d, load %d, val: %d, smem_ptr: %p\n", lane_id, i, smem_ptr[0], smem_ptr); - cutlass::arch::ldsm(frag_ptr[i], smem_ptr); - smem_ptr += kSwizzleTileSize * kStrideM * kLoadVectorSize; - } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < SmemDimM / 16; ++i) { + // printf("lane_id: %d, load %d, val: %d, smem_ptr: %p\n", lane_id, i, smem_ptr[0], smem_ptr); + cutlass::arch::ldsm(frag_ptr[i], smem_ptr); + smem_ptr += kSwizzleTileSize * kStrideM * kLoadVectorSize; } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace mickey +} // namespace warp +} // namespace gemm +} // namespace mickey diff --git a/onnxruntime/core/mickey/gemm/warp/tensor_core_tile_loader.h b/onnxruntime/core/mickey/gemm/warp/tensor_core_tile_loader.h index baab1d160f435..bbbde02158846 100644 --- a/onnxruntime/core/mickey/gemm/warp/tensor_core_tile_loader.h +++ b/onnxruntime/core/mickey/gemm/warp/tensor_core_tile_loader.h @@ -32,47 +32,47 @@ template < /// Number of tiles in the K dimension int KTiles> class TensorCoreTileLoader { - public: - // Number of tiles must be loaded from global memory to shared memory with a single ld.async - // instruction by each thread in the warp, and a single ldmatrix instruction by the warp. - static constexpr int kMNTiles = MNTiles; - static constexpr int kKTiles = KTiles; - static constexpr int kTiles = kMNTiles * kKTiles; - static_assert(kTiles == 1 || kTiles == 2 || kTiles == 4, "Number of tiles must be 1, 2 or 4"); + public: + // Number of tiles must be loaded from global memory to shared memory with a single ld.async + // instruction by each thread in the warp, and a single ldmatrix instruction by the warp. + static constexpr int kMNTiles = MNTiles; + static constexpr int kKTiles = KTiles; + static constexpr int kTiles = kMNTiles * kKTiles; + static_assert(kTiles == 1 || kTiles == 2 || kTiles == 4, "Number of tiles must be 1, 2 or 4"); - static constexpr int kMNThreads = kMNTiles * 8; - static constexpr int kKThreads = kKTiles; - static constexpr int kThreads = kMNThreads * kKThreads; + static constexpr int kMNThreads = kMNTiles * 8; + static constexpr int kKThreads = kKTiles; + static constexpr int kThreads = kMNThreads * kKThreads; - /// Each tensor core tile is 16x8 in size - static constexpr int kMNStride = kMNTiles * 8; - static constexpr int kKStride = kKTiles * 16; - static constexpr int kByteSize = kTiles * 16 * 8; + /// Each tensor core tile is 16x8 in size + static constexpr int kMNStride = kMNTiles * 8; + static constexpr int kKStride = kKTiles * 16; + static constexpr int kByteSize = kTiles * 16 * 8; private: - /// Pointer to global memory to load data from - uint8_t const* g_ptr_{nullptr}; - /// Iteration boundaries in the M or N dimension - int mn_cnt_{0}; - /// Iteration boundaries in the K dimension, in strides of 16 - int k16_cnt_{0}; - /// Stride in bytes to advance to next row in m or n dimension - const int stride_; - /// thread id in a warp - const int lane_id_; - + /// Pointer to global memory to load data from + uint8_t const* g_ptr_{nullptr}; + /// Iteration boundaries in the M or N dimension + int mn_cnt_{0}; + /// Iteration boundaries in the K dimension, in strides of 16 + int k16_cnt_{0}; + /// Stride in bytes to advance to next row in m or n dimension + const int stride_; + /// thread id in a warp + const int lane_id_; + public: /// Construct a TileIterator with zero threadblock offset CUTLASS_HOST_DEVICE TensorCoreTileLoader( - void const* data_ptr, ///< Pointer to the global memory tiles - int byte_stride, ///< Stride in bytes to advance to next row - int mn_start, ///< Starting position in the M or N dimension - int mn_end, ///< End position in the M or N dimension - int k_start, ///< Starting position in the K dimension - int k_end, ///< End position in the K dimension - int lane_id) ///< ID of each participating thread - : stride_(byte_stride), lane_id_(lane_id){ + void const* data_ptr, ///< Pointer to the global memory tiles + int byte_stride, ///< Stride in bytes to advance to next row + int mn_start, ///< Starting position in the M or N dimension + int mn_end, ///< End position in the M or N dimension + int k_start, ///< Starting position in the K dimension + int k_end, ///< End position in the K dimension + int lane_id) ///< ID of each participating thread + : stride_(byte_stride), lane_id_(lane_id) { #ifndef NDEBUG bool assertion_pass = true; if (reinterpret_cast(data_ptr) % 16 != 0) { @@ -172,10 +172,9 @@ class TensorCoreTileLoader { /** * @brief Get the pointer to the shared memory location for the current lane * @param smem_ptr pointer to the shared memory location for the warp. - */ - template - CUTLASS_DEVICE - T* get_smem_lane_ptr(T* smem_ptr) const { + */ + template + CUTLASS_DEVICE T* get_smem_lane_ptr(T* smem_ptr) const { if constexpr (kThreads < 32) { static_assert(kThreads & (kThreads - 1) == 0, "kThreads must be power of 2"); return reinterpret_cast(reinterpret_cast(smem_ptr) + ((lane_id_ & (kThreads - 1)) << 4)); @@ -184,9 +183,8 @@ class TensorCoreTileLoader { } } - template - CUTLASS_DEVICE - T* get_smem_warp_base_ptr(T* smem_lane_ptr) const { + template + CUTLASS_DEVICE T* get_smem_warp_base_ptr(T* smem_lane_ptr) const { if constexpr (kThreads < 32) { static_assert(kThreads & (kThreads - 1) == 0, "kThreads must be power of 2"); return reinterpret_cast(reinterpret_cast(smem_lane_ptr) - ((lane_id_ & (kThreads - 1)) << 4)); @@ -195,7 +193,6 @@ class TensorCoreTileLoader { } } - /// Loads a tile from global memory to shared memory CUTLASS_DEVICE void load_to(void* smem_lane_ptr) const { @@ -227,7 +224,7 @@ class TensorCoreTileLoader { if (g_ptr_ == nullptr) { return *this; } - + k16_cnt_ -= kKTiles; if (k16_cnt_ > 0) { g_ptr_ += 16 * kKTiles; @@ -237,9 +234,8 @@ class TensorCoreTileLoader { return *this; } - template - CUTLASS_DEVICE - void load_lateral_n(void* smem_lane_ptr) const { + template + CUTLASS_DEVICE void load_lateral_n(void* smem_lane_ptr) const { uint8_t* smem_bytes = reinterpret_cast(smem_lane_ptr); this->load_to(smem_bytes); smem_bytes += kByteSize; @@ -255,9 +251,8 @@ class TensorCoreTileLoader { cutlass::arch::ldsm(frag, smem_lane_ptr); } - template - CUTLASS_DEVICE - static void multi_ldmatrix_sync(cutlass::Array& fragment, T2 const* &smem_lane_ptr) { + template + CUTLASS_DEVICE static void multi_ldmatrix_sync(cutlass::Array& fragment, T2 const*& smem_lane_ptr) { static_assert(sizeof(unsigned) * kTiles * Loads == sizeof(T1) * Size, "Fragment size mismatch"); cutlass::Array* ptr = reinterpret_cast*>(fragment.data()); @@ -268,7 +263,6 @@ class TensorCoreTileLoader { smem_lane_ptr += kByteSize / sizeof(T2); } } - }; } // namespace warp diff --git a/onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx2.asm index a42d7ff8730cb..9c334bea2f468 100644 --- a/onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx2.asm +++ b/onnxruntime/core/mlas/lib/amd64/ConvSymKernelAvx2.asm @@ -23,6 +23,87 @@ INCLUDE ConvSymKernelCommon.inc INCLUDE AssembleAvxVnni.inc .list +extern CheckSaturationForVPMADDUBSW:proc + +CheckSaturation MACRO VecReg1Num, VecReg2Num + +; +; Save all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11). no RSI, RDI. +; + + push_reg rax + push_reg rcx + push_reg rdx + push_reg r8 + push_reg r9 + push_reg r10 + push_reg r11 + + sub rsp, 512 ; reserve space for 16 YMM registers (32 bytes) + +; +; Save YMM registers (YMM0 to YMM15). +; + + vmovdqu YMMWORD PTR [rsp], ymm0 + vmovdqu YMMWORD PTR [rsp+32], ymm1 + vmovdqu YMMWORD PTR [rsp+64], ymm2 + vmovdqu YMMWORD PTR [rsp+96], ymm3 + vmovdqu YMMWORD PTR [rsp+128], ymm4 + vmovdqu YMMWORD PTR [rsp+160], ymm5 + vmovdqu YMMWORD PTR [rsp+192], ymm6 + vmovdqu YMMWORD PTR [rsp+224], ymm7 + vmovdqu YMMWORD PTR [rsp+256], ymm8 + vmovdqu YMMWORD PTR [rsp+288], ymm9 + vmovdqu YMMWORD PTR [rsp+320], ymm10 + vmovdqu YMMWORD PTR [rsp+352], ymm11 + vmovdqu YMMWORD PTR [rsp+384], ymm12 + vmovdqu YMMWORD PTR [rsp+416], ymm13 + vmovdqu YMMWORD PTR [rsp+448], ymm14 + vmovdqu YMMWORD PTR [rsp+480], ymm15 + + lea rcx, [rsp+32*VecReg1Num] ; first operand (unsigned) + lea rdx, [rsp+32*VecReg2Num] ; second operand (signed) + + call CheckSaturationForVPMADDUBSW + +; +; Restore YMM registers. +; + + vmovdqu ymm0, YMMWORD PTR [rsp] + vmovdqu ymm1, YMMWORD PTR [rsp+32] + vmovdqu ymm2, YMMWORD PTR [rsp+64] + vmovdqu ymm3, YMMWORD PTR [rsp+96] + vmovdqu ymm4, YMMWORD PTR [rsp+128] + vmovdqu ymm5, YMMWORD PTR [rsp+160] + vmovdqu ymm6, YMMWORD PTR [rsp+192] + vmovdqu ymm7, YMMWORD PTR [rsp+224] + vmovdqu ymm8, YMMWORD PTR [rsp+256] + vmovdqu ymm9, YMMWORD PTR [rsp+288] + vmovdqu ymm10, YMMWORD PTR [rsp+320] + vmovdqu ymm11, YMMWORD PTR [rsp+352] + vmovdqu ymm12, YMMWORD PTR [rsp+384] + vmovdqu ymm13, YMMWORD PTR [rsp+416] + vmovdqu ymm14, YMMWORD PTR [rsp+448] + vmovdqu ymm15, YMMWORD PTR [rsp+480] + + add rsp, 512 ; clean up the reserved stack space + +; +; Restore all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11), no RSI, RDI. +; + + pop r11 + pop r10 + pop r9 + pop r8 + pop rdx + pop rcx + pop rax + + ENDM + ; ; Macro Description: ; @@ -50,9 +131,15 @@ INCLUDE AssembleAvxVnni.inc MultiplyAccumulateRowAvx2 MACRO Vec1Reg, Vec2Reg +IFDEF ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER + CheckSaturation 2,0 +ENDIF vpmaddubsw ymm3,ymm2,ymm0 vpmaddwd ymm3,ymm3,ymm12 vpaddd Vec1Reg,Vec1Reg,ymm3 +IFDEF ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER + CheckSaturation 2,1 +ENDIF vpmaddubsw ymm2,ymm2,ymm1 vpmaddwd ymm2,ymm2,ymm12 vpaddd Vec2Reg,Vec2Reg,ymm2 diff --git a/onnxruntime/core/mlas/lib/intrinsics/avx2/saturation_check_avx2.cpp b/onnxruntime/core/mlas/lib/intrinsics/avx2/saturation_check_avx2.cpp new file mode 100644 index 0000000000000..5ff4c0ee024d3 --- /dev/null +++ b/onnxruntime/core/mlas/lib/intrinsics/avx2/saturation_check_avx2.cpp @@ -0,0 +1,62 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + saturation_check_avx2.cpp + +Abstract: + + This module implements logic to check saturation of the VPMADDUBSW + instruction. + +--*/ + +#include + +#include +#include + +namespace onnxruntime +{ +extern std::atomic saturation_count; +} + +extern "C" void +CheckSaturationForVPMADDUBSW(const __m256i* unsigned_ptr, const __m256i* signed_ptr) +{ + // Load data from memory (unaligned load) + __m256i unsigned_data = _mm256_loadu_si256(unsigned_ptr); + __m256i signed_data = _mm256_loadu_si256(signed_ptr); + + alignas(32) uint8_t unsigned_bytes[32]; // Unsigned input values + alignas(32) int8_t signed_bytes[32]; // Signed input values + + // Store the data into the byte arrays + _mm256_store_si256(reinterpret_cast<__m256i*>(unsigned_bytes), unsigned_data); + _mm256_store_si256(reinterpret_cast<__m256i*>(signed_bytes), signed_data); + + bool saturation_detected = false; + + // Iterate through the 16 pairs of 8-bit unsigned and signed values + for (int i = 0; i < 16; ++i) { + // Perform the VPMADDUBSW operation in higher precision (int32_t) + int32_t computed_value = + static_cast(signed_bytes[2 * i]) * static_cast(static_cast(unsigned_bytes[2 * i])) + + static_cast(signed_bytes[2 * i + 1]) * static_cast(static_cast(unsigned_bytes[2 * i + 1])); + + // If the computed value exceeds the 16-bit signed integer range, saturation occurred + if (computed_value > INT16_MAX || computed_value < INT16_MIN) { + saturation_detected = true; + break; + } + } + + // If saturation is detected, log a warning (only log once based on the atomic count) + if (saturation_detected && ++onnxruntime::saturation_count < 2) { + std::cerr << "Warning: saturation detected in VPMADDUBSW instruction." << std::endl; + } +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 4782e479753a2..184816ac24c43 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -18,6 +18,7 @@ Module Name: #pragma once #include +#include #include #include #include @@ -1144,6 +1145,10 @@ struct MLAS_PLATFORM { MLAS_PLATFORM(void); + // TODO: move to cpuinfo + bool Avx2Supported_ = false; + bool Avx512Supported_ = false; + #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; #endif diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 2165252ccd4cc..7724259e7c228 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -374,6 +374,8 @@ Return Value: if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) { + this->Avx2Supported_ = true; + this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx2; @@ -465,6 +467,8 @@ Return Value: if ((Cpuid7[1] & 0xC0020000) == 0xC0020000) { + this->Avx512Supported_ = true; + this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Core; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Core; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index eafe91575c528..19d11a60b7376 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -28,10 +28,11 @@ enum QNBitGemmVariant { // Valid variants - SQNBitGemmVariant_BitWidth4_CompFp32 = 0, - SQNBitGemmVariant_BitWidth4_CompInt8, - HQNBitGemmVariant_BitWidth4_CompFp16, - HQNBitGemmVariant_BitWidth4_CompInt8, + SQ4BitGemmVariant_CompFp32 = 0, + SQ4BitGemmVariant_CompInt8, + HQ4BitGemmVariant_CompFp16, + HQ4BitGemmVariant_CompInt8, + SQ8BitGemmVariant_CompInt8, // End of valid variants @@ -47,16 +48,21 @@ GetQNBitGemmVariant( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - if (BlkBitWidth == 4 && - (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == SQNBIT_CompFp32) { - return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == HQNBIT_CompFp16) { - return HQNBitGemmVariant_BitWidth4_CompFp16; - } else if (ComputeType == SQNBIT_CompInt8) { - return SQNBitGemmVariant_BitWidth4_CompInt8; - } else if (ComputeType == HQNBIT_CompInt8) { - return HQNBitGemmVariant_BitWidth4_CompInt8; + if ((BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { + if (BlkBitWidth == 4) { + if (ComputeType == SQNBIT_CompFp32) { + return SQ4BitGemmVariant_CompFp32; + } else if (ComputeType == HQNBIT_CompFp16) { + return HQ4BitGemmVariant_CompFp16; + } else if (ComputeType == SQNBIT_CompInt8) { + return SQ4BitGemmVariant_CompInt8; + } else if (ComputeType == HQNBIT_CompInt8) { + return HQ4BitGemmVariant_CompInt8; + } + } else if (BlkBitWidth == 8) { + if (ComputeType == SQNBIT_CompInt8) { + return SQ8BitGemmVariant_CompInt8; + } } } @@ -80,21 +86,26 @@ MlasIsQNBitGemmAvailable( const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { - case SQNBitGemmVariant_BitWidth4_CompFp32: { + case SQ4BitGemmVariant_CompFp32: { return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && Dispatch->SQ4BitBlkDequantBForSgemm_CompFp32 != nullptr; } - case HQNBitGemmVariant_BitWidth4_CompFp16: { + case HQ4BitGemmVariant_CompFp16: { return Dispatch->HQ4BitGemmPackQuantBData != nullptr && Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr && Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr; } - case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 + case SQ4BitGemmVariant_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 return (Dispatch->SQ4BitGemmKernel_Packed_CompInt8 != nullptr && Dispatch->QuantizeA_Packed_CompInt8 != nullptr) || (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); } + case SQ8BitGemmVariant_CompInt8: { + return Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr && + Dispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr && + Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr; + } default: { return false; } @@ -116,12 +127,12 @@ QNBitGemmPerGemmWorkspaceSize( ) { const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; - if (Dispatch == nullptr) { + if (Dispatch == nullptr || Dispatch->QNBitGemmPerGemmWorkspaceSize == nullptr) { return 0; } - if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceSize != nullptr) { - return Dispatch->Q4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType); + if (BlkBitWidth == 4 || BlkBitWidth == 8) { + return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType); } return 0; @@ -135,12 +146,12 @@ QNBitGemmPerGemmWorkspaceAlignment( ) { const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; - if (Dispatch == nullptr) { + if (Dispatch == nullptr || Dispatch->QNBitGemmPerGemmWorkspaceAlignment == nullptr) { return 1; } - if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceAlignment != nullptr) { - return Dispatch->Q4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); + if (BlkBitWidth == 4 || BlkBitWidth == 8) { + return Dispatch->QNBitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); } return 1; @@ -208,6 +219,10 @@ MlasQNBitGemmPackQuantBDataSize( return Dispatch->Q4BitGemmPackQuantBDataSize( N, K, BlkLen, HasZeroPoint, ComputeType ); + } else if (BlkBitWidth == 8 && Dispatch->Q8BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->Q8BitGemmPackQuantBDataSize( + N, K, BlkLen, HasZeroPoint, ComputeType + ); } return 0; @@ -251,7 +266,7 @@ MlasQNBitGemmPackQuantBData( if (BlkBitWidth == 4) { if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, @@ -289,6 +304,23 @@ MlasQNBitGemmPackQuantBData( ); return; } + } else if (BlkBitWidth == 8) { + if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr) { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(QuantBScale), + HasZeroPoint, + static_cast(QuantBZeroPoint), + packed_quant_b, + ThreadPool + ); + } } } @@ -675,6 +707,86 @@ SQ4BitGemm_CompInt8( } } +void +SQ8BitGemm_CompInt8( + const size_t BlkLen, + const size_t K, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); + constexpr size_t BlkBitWidth = 8; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + // quant A scale is embedded in QuantData if QuantScale is nullptr. + const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? BlkLen : Q8BlkSize(BlkLen)); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; + const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; + + assert(RangeStartN % 16 == 0); + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + if (GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8( + BlkLen, + QuantA, + QuantAScale, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + ldc, + ABlockSum, + b_blk_sum + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + } +} + template void InitializeWorkspace_CompInt8( @@ -802,7 +914,8 @@ InitializeWorkspaceFn GetInitializeWorkspace(QNBitGemmVariant variant) { switch (variant) { - case SQNBitGemmVariant_BitWidth4_CompInt8: + case SQ4BitGemmVariant_CompInt8: + case SQ8BitGemmVariant_CompInt8: return InitializeWorkspace_CompInt8; default: return nullptr; @@ -814,7 +927,7 @@ InitializeWorkspaceFn GetInitializeWorkspace(QNBitGemmVariant variant) { switch (variant) { - case HQNBitGemmVariant_BitWidth4_CompInt8: + case HQ4BitGemmVariant_CompInt8: return InitializeWorkspace_CompInt8; default: return nullptr; @@ -842,10 +955,12 @@ QNBitGemmFn GetQNBitGemm(QNBitGemmVariant variant) { switch (variant) { - case SQNBitGemmVariant_BitWidth4_CompFp32: + case SQ4BitGemmVariant_CompFp32: return SQ4BitGemm_CompFp32; - case SQNBitGemmVariant_BitWidth4_CompInt8: + case SQ4BitGemmVariant_CompInt8: return SQ4BitGemm_CompInt8; + case SQ8BitGemmVariant_CompInt8: + return SQ8BitGemm_CompInt8; default: return nullptr; } @@ -856,7 +971,7 @@ QNBitGemmFn GetQNBitGemm(QNBitGemmVariant variant) { switch (variant) { - case HQNBitGemmVariant_BitWidth4_CompFp16: + case HQ4BitGemmVariant_CompFp16: return HQ4BitGemm_CompFp16; default: return nullptr; @@ -913,8 +1028,15 @@ MlasQNBitGemmBatch( const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); + } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; @@ -984,8 +1106,16 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index e25455cbfa217..a740801e00514 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -46,13 +46,11 @@ MlasAlignAddress(void* addr, const size_t alignment) return addr; } -template +template struct PackedQuantBDataStruct { PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) { - // TODO: duplicate code from Q4BitGemmPackQuantBDataSize - constexpr size_t BlkBitWidth = 4; const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); #if defined(MLAS_TARGET_AMD64_IX86) @@ -104,6 +102,17 @@ struct MLAS_QNBIT_GEMM_DISPATCH { Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; + /** Gets size of packed quantized B data containing 8-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ + typedef size_t(Q8BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + Q8BitGemmPackQuantBDataSize_Fn* Q8BitGemmPackQuantBDataSize = nullptr; + /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */ typedef void(Q4BitGemmPackQuantBData_Fn)( size_t N, @@ -127,12 +136,27 @@ struct MLAS_QNBIT_GEMM_DISPATCH { const float* QuantBScaleBegin, bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& PackedQuantB, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ); SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndBlkSum = nullptr; + typedef void(SQ8BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool + ); + + SQ8BitGemmPackQuantBDataAndSumBlk_Fn* SQ8BitGemmPackQuantBDataAndBlkSum = nullptr; + // // Workspace size calculation function prototypes. // @@ -148,7 +172,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { * @param[in] HasZeroPoint whether zero points are provided * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(Q4BitGemmPerGemmWorkspaceSize_Fn)( + typedef size_t(QNBitGemmPerGemmWorkspaceSize_Fn)( size_t M, size_t N, size_t K, @@ -157,7 +181,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; + QNBitGemmPerGemmWorkspaceSize_Fn* QNBitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. @@ -165,12 +189,12 @@ struct MLAS_QNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(Q4BitGemmPerGemmWorkspaceAlignment_Fn)( + typedef size_t(QNBitGemmPerGemmWorkspaceAlignment_Fn)( size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr; + QNBitGemmPerGemmWorkspaceAlignment_Fn* QNBitGemmPerGemmWorkspaceAlignment = nullptr; // // SQNBIT_CompFp32 kernel function prototypes. @@ -345,6 +369,45 @@ struct MLAS_QNBIT_GEMM_DISPATCH { SQ4BitGemmKernel_BlkSum_CompInt8_Fn* SQ4BitGemmKernel_BlkSum_CompInt8 = nullptr; + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 8-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. + */ + typedef size_t(SQ8BitGemmKernel_BlkSum_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum + ); + + SQ8BitGemmKernel_BlkSum_CompInt8_Fn* SQ8BitGemmKernel_BlkSum_CompInt8 = nullptr; + /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index ab71492805e9c..0d06eb04e5245 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -155,7 +155,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const float* QuantBScaleBegin, bool HasZeroPoint, const std::byte*, - PackedQuantBDataStruct& PackedQuantB, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -204,7 +204,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum( // size_t -Q4BitGemmPerGemmWorkspaceSize( +QNBitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, @@ -245,7 +245,7 @@ Q4BitGemmPerGemmWorkspaceSize( } size_t -Q4BitGemmPerGemmWorkspaceAlignment( +QNBitGemmPerGemmWorkspaceAlignment( size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) @@ -299,8 +299,8 @@ GetMlasQNBitGemmDispatchNeon( d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataAndBlkSum; - d.Q4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceSize; - d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32; diff --git a/onnxruntime/core/mlas/lib/quantize.cpp b/onnxruntime/core/mlas/lib/quantize.cpp index ae638fafee18f..fad174f747169 100644 --- a/onnxruntime/core/mlas/lib/quantize.cpp +++ b/onnxruntime/core/mlas/lib/quantize.cpp @@ -1704,8 +1704,8 @@ MlasRequantizeOutput( float min_f = float(std::numeric_limits::lowest() - ZeroPoint); float max_f = float(std::numeric_limits::max() - ZeroPoint); const __m128 PerMatrixScaleVector = PerColumnScale ? MlasReinterpretAsFloat32x4(__lsx_vldi(0)) : MlasReinterpretAsFloat32x4(__lsx_vldrepl_w(Scale, 0)); - const __m128 MinimumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&min_f))); - const __m128 MaximumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&max_f))); + const __m128 MinimumValueVector = MlasReinterpretAsFloat32x4((__m128i)(v4f32){min_f,min_f,min_f,min_f}); + const __m128 MaximumValueVector = MlasReinterpretAsFloat32x4((__m128i)(v4f32){max_f,max_f,max_f,max_f}); const __m128i ZeroPointVector = __lsx_vreplgr2vr_w(ZeroPoint); if (nullptr != Bias) { diff --git a/onnxruntime/core/mlas/lib/saturation_check.cpp b/onnxruntime/core/mlas/lib/saturation_check.cpp new file mode 100644 index 0000000000000..7b022a7563c70 --- /dev/null +++ b/onnxruntime/core/mlas/lib/saturation_check.cpp @@ -0,0 +1,42 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + saturation_check.cpp + +Abstract: + + This module implements logic to check saturation of the VPMADDUBSW + instruction. + +--*/ + +#include "mlasi.h" + +namespace onnxruntime +{ + +#if defined(MLAS_TARGET_AMD64) + +std::atomic saturation_count{0}; + +void +reset_saturation_count() +{ + saturation_count = 0; +} + +#else + +void +reset_saturation_count() +{ +} + +#endif + +} // namespace onnxruntime diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 79893eea85eca..384b04c807195 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -584,6 +584,88 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2( return CountM; } +template +MLAS_FORCEINLINE +size_t +SQ8BitGemmKernel_BlkSum_CompInt8_avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen == 16) { + MlasQ8Int8GemmKernelBlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ8Int8GemmKernelBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ8Int8GemmKernelBlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + size_t SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni( const size_t BlkLen, @@ -1311,7 +1393,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const float* QuantBScaleBegin, bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& PackedQuantB, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -1328,23 +1410,52 @@ SQ4BitGemmPackQuantBDataAndBlkSum( HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); } +static void +SQ8BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 64; + } + Q8PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + // // Kernel dispatch structure definition. // const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; - d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; - d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; @@ -1353,17 +1464,20 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; - d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; - d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h index 445ead329acf8..aec5dc9c3b9c7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -6,6 +6,13 @@ #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx2BlkLen16[40], 32) = { + 0x00000000, 0x00000000, 0x00000002, 0x00000002, 0x00000001, 0x00000001, 0x00000003, 0x00000003, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000001, 0x00000001, 0x00000001, 0x00000001 +}; MLAS_FORCEINLINE __m256 load_and_broadcast_4_scale_2(const float* scale) @@ -152,6 +159,208 @@ accumulate_blklen16_r2c1blk4_avx2( scale_a0, scale_a1, scale_b, acc0, acc1); } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk4_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m256& acc0 +) +{ + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + // 22222222 22222222, 33333333 33333333 + const __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + // 00 22, 11 33 + const __m256i scale_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16)); + // 0123, 0123 + __m256 scale_b_4_ps = _mm256_broadcast_ps((const __m128*)scale_b); + __m256 scale_a0_4_ps = _mm256_broadcast_ps((const __m128*)scale_a0); + __m256 scale_a0b_4_ps = _mm256_mul_ps(scale_b_4_ps, scale_a0_4_ps); + __m256 scale_a0b_4_shuffle_ps = _mm256_permutevar_ps(scale_a0b_4_ps, scale_mask); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + // 0000, 1111 + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + // 2222, 3333 + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + // 0022, 1133 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 8)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 16)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 24)); + + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + // 0000 0000, 1111 1111 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + // 00 00, 11 11 + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + // 22 22, 33 33 + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + // 00 22, 11 33 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r2c1blk4_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + // 22222222 22222222, 33333333 33333333 + const __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + // 00 22, 11 33 + const __m256i scale_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16)); + // 0123, 0123 + __m256 scale_b_4_ps = _mm256_broadcast_ps((const __m128*)scale_b); + __m256 scale_a0_4_ps = _mm256_broadcast_ps((const __m128*)scale_a0); + __m256 scale_a0b_4_ps = _mm256_mul_ps(scale_b_4_ps, scale_a0_4_ps); + __m256 scale_a0b_4_shuffle_ps = _mm256_permutevar_ps(scale_a0b_4_ps, scale_mask); + __m256 scale_a1_4_ps = _mm256_broadcast_ps((const __m128*)scale_a1); + __m256 scale_a1b_4_ps = _mm256_mul_ps(scale_b_4_ps, scale_a1_4_ps); + __m256 scale_a1b_4_shuffle_ps = _mm256_permutevar_ps(scale_a1b_4_ps, scale_mask); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + // 0000, 1111 + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + // 2222, 3333 + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + // 0022, 1133 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + + const __m256i dot10_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot11_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, scale_a1b_4_shuffle_ps, acc1); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 8)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 16)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 24)); + + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + // row 0 + // 0000 0000, 1111 1111 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + // 00 00, 11 11 + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + // 22 22, 33 33 + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + // 00 22, 11 33 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + const __m256i dot11_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av11_32_epi8); + const __m256i dot11_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av11_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + const __m256i dot11_low_8_epi32 = _mm256_madd_epi16(one_mask, dot11_low_16_epi16); + const __m256i dot11_high_8_epi32 = _mm256_madd_epi16(one_mask, dot11_high_16_epi16); + const __m256i dot11_8_epi32 = _mm256_add_epi32(dot11_low_8_epi32, dot11_high_8_epi32); + + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, scale_a1b_4_shuffle_ps, acc1); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk1_avx2( + const __m128i& av00_16_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + __m256& acc0 +) +{ + const __m128i bv0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantBDataPtr)); + __m256 scale_a0b_1_ps = _mm256_set1_ps(scale_a0b); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + const __m128i dot00_4_epi32 = _mm_dpbusds_avx_epi32(_mm_setzero_si128(), bv0_16_epi8, av00_16_epi8); + const __m256i dot00_8_epi32 = _mm256_cvtepu32_epi64(dot00_4_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_1_ps, acc0); + } + else +#endif + { + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 24)); + const __m256i bv0_32_epi8 = _mm256_cvtepu8_epi16(bv0_16_epi8); + const __m256i av00_32_epi8 = _mm256_cvtepu8_epi16(av00_16_epi8); + const __m256i dot00_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot00_8_epi32 = _mm256_madd_epi16(one_mask, dot00_16_epi16); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_1_ps, acc0); + } +} + static MLAS_FORCEINLINE void accumulate_blklen16_r1c1blk4_avx2( const __m256i& av0_32_epi8, @@ -332,6 +541,118 @@ Q4Int8GemmR2xC4BlkLen16Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData4 = BlkDataSizeInBytes * PerAccuBlk4; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 3; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData4, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData4, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData4, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const __m128i av_10_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr + lda)); + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_b0 = *QuantBScalePtr; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_a00 * scale_b0, acc[0]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr, scale_a10 * scale_b0, acc[NCols4]); + + const float scale_b1 = *(QuantBScalePtr + 1); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a00 * scale_b1, acc[1]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a10 * scale_b1, acc[NCols4 + 1]); + + const float scale_b2 = *(QuantBScalePtr + 2); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a00 * scale_b2, acc[2]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a10 * scale_b2, acc[NCols4 + 2]); + + const float scale_b3 = *(QuantBScalePtr + 3); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a00 * scale_b3, acc[3]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a10 * scale_b3, acc[NCols4 + 3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr+= NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( const std::byte* QuantA, const float* QuantAScale, @@ -437,6 +758,108 @@ void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + 32; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + 32; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); + + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m128i av0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantABlk0)); + const __m128i av1_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantABlk0 + lda)); + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_b0 = *QuantBScalePtr; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr, scale_a00 * scale_b0, acc0); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr, scale_a10 * scale_b0, acc1); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen16Avx2( const std::byte* QuantA, @@ -549,6 +972,106 @@ Q4Int8GemmR1xC4BlkLen16Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + const size_t StrideQuantBData4 = BlkDataSizeInBytes * PerAccuBlk4; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData4, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData4, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData4, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += StrideQuantBData4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const float scale_a00 = *QuantAScalePtr; + + const float scale_b0 = *QuantBScalePtr; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_a00 * scale_b0, acc[0]); + + const float scale_b1 = *(QuantBScalePtr + 1); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a00 * scale_b1, acc[1]); + + const float scale_b2 = *(QuantBScalePtr + 2); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a00 * scale_b2, acc[2]); + + const float scale_b3 = *(QuantBScalePtr + 3); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a00 * scale_b3, acc[3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen16Avx2( const std::byte* QuantA, @@ -634,6 +1157,90 @@ Q4Int8GemmR1xC1BlkLen16Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const float scale_a00 = *QuantAScalePtr; + + const float scale_a0b = scale_a00 * (*QuantBScalePtr); + accumulate_q8_blklen16_r1c1blk1_avx2(av_16_epi8, QuantBDataPtr, scale_a0b, acc0); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen16Avx2( @@ -725,3 +1332,95 @@ MLAS_FORCEINLINE return CountM; } + +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index 5dab8091ce760..09d53f9b852db 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -6,6 +6,11 @@ #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx2BlkLen32[24], 32) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001 +}; MLAS_FORCEINLINE void accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, @@ -115,6 +120,156 @@ accumulate_blklen32_r2c1blk2_avx2( #endif } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m256& acc0 +) +{ + const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); // 01 01 01 01 + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_a0b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a0_2_ps); + __m256 scale0_8_ps = _mm256_shuffle_ps(scale_a0b_2_ps, scale_a0b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)); // 00 11 00 11 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11 00 11 + const __m256 sum0_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_ps, scale0_8_ps, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11, 00 11 + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale0_8_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); // 01 01 01 01 + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_a0b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a0_2_ps); + __m256 scale0_8_ps = _mm256_shuffle_ps(scale_a0b_2_ps, scale_a0b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)); // 00 11 00 11 + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_a1b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a1_2_ps); + __m256 scale1_8_ps = _mm256_shuffle_ps(scale_a1b_2_ps, scale_a1b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)); // 00 11 00 11 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11 00 11 + const __m256 sum0_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_ps, scale0_8_ps, acc0); + + const __m256i dot10_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot11_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); // 00 11 00 11 + const __m256 sum1_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_ps, scale1_8_ps, acc1); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + // row 0 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11, 00 11 + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale0_8_ps, acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + const __m256i dot11_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av11_32_epi8); + const __m256i dot11_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av11_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + const __m256i dot11_low_8_epi32 = _mm256_madd_epi16(one_mask, dot11_low_16_epi16); + const __m256i dot11_high_8_epi32 = _mm256_madd_epi16(one_mask, dot11_high_16_epi16); + const __m256i dot11_8_epi32 = _mm256_add_epi32(dot11_low_8_epi32, dot11_high_8_epi32); + + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); // 00 11, 00 11 + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, scale1_8_ps, acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_avx2( @@ -196,6 +351,100 @@ accumulate_blklen32_r2c1blk1_avx2( #endif } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + __m256& acc0 +) +{ + const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + accumulate_1blk_dot_vnni(av00_32_epi8, bv0_32_epi8, combined_scale00, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + __m256 dot00_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(dot00_8_ps, _mm256_set1_ps(combined_scale00), acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + float combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + accumulate_1blk_dot_vnni(av00_32_epi8, bv0_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_vnni(av10_32_epi8, bv0_32_epi8, combined_scale10, acc1); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + + // row 0 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + __m256 dot00_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(dot00_8_ps, _mm256_set1_ps(combined_scale00), acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + __m256 dot10_8_ps = _mm256_cvtepi32_ps(dot10_8_epi32); + acc1 = _mm256_fmadd_ps(dot10_8_ps, _mm256_set1_ps(combined_scale10), acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk1_avx2( @@ -367,6 +616,116 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk2 = 2; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData2 = PerAccuBlk2 * BlkDataSizeInBytes; + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + if (k_blks_remaining > 0) { + // load A + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + float scale_a00 = *QuantAScalePtr; + float scale_a10 = *(QuantAScalePtr + BlockCountK); + + float scale_00 = scale_a00 * (QuantBScalePtr)[0], scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + + float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0], scale_11 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, scale_11, acc[1], acc[NCols4 + 1]); + + float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0], scale_12 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, scale_12, acc[2], acc[NCols4 + 2]); + + float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0], scale_13 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, scale_13, acc[3], acc[NCols4 + 3]); + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, @@ -460,6 +819,95 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmXx4BlkLen32Avx2( @@ -589,6 +1037,100 @@ Q4Int8GemmXx4BlkLen32Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk2 = 2; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData2 = PerAccuBlk2 * BlkDataSizeInBytes; + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + if (k_blks_remaining > 0) { + // load A + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float scale_a00 = *QuantAScalePtr; + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, acc[1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, acc[2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, acc[3]); + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmXxXBlkLen32Avx2( @@ -672,6 +1214,81 @@ Q4Int8GemmXxXBlkLen32Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes16; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t @@ -765,6 +1382,98 @@ MLAS_FORCEINLINE return CountM; } +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} + // this function is to explore larger NCols. With Avx2 it does not improve performance. // Leave it here until the same is implemented in avx512. template accumulator> diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index d4b89bd9bad2d..2bf27df2dccce 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -6,6 +6,12 @@ #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx2BlkLen64[24], 32) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001 +}; + template static MLAS_FORCEINLINE void accumulate_blklen64_r2c1blk1_avx2( @@ -76,6 +82,143 @@ accumulate_blklen64_r2c1blk1_avx2( #endif } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + __m256& acc0 +) +{ + __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256 scale_8_ps = _mm256_set1_ps(scale_a0b); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_add_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_8_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + float scale_a1b, + __m256& acc0, + __m256& acc1 +) +{ + __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256 scale0_8_ps = _mm256_set1_ps(scale_a0b); + __m256 scale1_8_ps = _mm256_set1_ps(scale_a1b); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + __m256i sum0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum0_8_epi32 = _mm256_dpbusds_avx_epi32(sum0_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum0_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_ps, scale0_8_ps, acc0); + + __m256i sum1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + sum1_8_epi32 = _mm256_dpbusds_avx_epi32(sum1_8_epi32, bv1_32_epi8, av11_32_epi8); + __m256 sum1_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_ps, scale1_8_ps, acc1); + } + else + #endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + // row 0 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_add_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale0_8_ps, acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + const __m256i dot11_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av11_32_epi8); + const __m256i dot11_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av11_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + const __m256i dot11_low_8_epi32 = _mm256_madd_epi16(one_mask, dot11_low_16_epi16); + const __m256i dot11_high_8_epi32 = _mm256_madd_epi16(one_mask, dot11_high_16_epi16); + const __m256i dot11_8_epi32 = _mm256_add_epi32(dot11_low_8_epi32, dot11_high_8_epi32); + + const __m256i sum1_8_epi32 = _mm256_add_epi32(dot10_8_epi32, dot11_8_epi32); + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, scale1_8_ps, acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk1_avx2( @@ -212,6 +355,105 @@ Q4Int8GemmR2xC4BlkLen64Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a1 = *(QuantAScalePtr + BlockCountK); + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + const float scale_a1b0 = (*QuantBScalePtr) * scale_a1; + const float scale_a0b1 = (*(QuantBScalePtr + 1)) * scale_a0; + const float scale_a1b1 = (*(QuantBScalePtr + 1)) * scale_a1; + const float scale_a0b2 = (*(QuantBScalePtr + 2)) * scale_a0; + const float scale_a1b2 = (*(QuantBScalePtr + 2)) * scale_a1; + const float scale_a0b3 = (*(QuantBScalePtr + 3)) * scale_a0; + const float scale_a1b3 = (*(QuantBScalePtr + 3)) * scale_a1; + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, scale_a0b2, scale_a1b2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen64Avx2( @@ -292,6 +534,91 @@ Q4Int8GemmR2xC1BlkLen64Avx2( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a1 = *(QuantAScalePtr + BlockCountK); + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + const float scale_a1b0 = (*QuantBScalePtr) * scale_a1; + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen64Avx2( @@ -371,6 +698,92 @@ Q4Int8GemmR1xC4BlkLen64Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + const float scale_a0b1 = (*(QuantBScalePtr + 1)) * scale_a0; + const float scale_a0b2 = (*(QuantBScalePtr + 2)) * scale_a0; + const float scale_a0b3 = (*(QuantBScalePtr + 3)) * scale_a0; + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, scale_a0b0, acc[0]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, acc[1]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, scale_a0b2, acc[2]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, scale_a0b3, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen64Avx2( @@ -447,6 +860,83 @@ Q4Int8GemmR1xC1BlkLen64Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, scale_a0b0, acc0); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen64Avx2( @@ -539,3 +1029,96 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( return CountM; } + +template +MLAS_FORCEINLINE size_t +MlasQ8Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index ea06f954c854a..c1bc00fbffa3e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -247,6 +247,99 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512( return CountM; } +MLAS_FORCEINLINE +size_t +SQ8BitGemmKernel_BlkSum_CompInt8_avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen == 16) { + MlasQ8Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ8Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ8Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ8Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + void MLASCALL QuantizeARow_CompInt8_avx512( size_t BlkLen, @@ -337,7 +430,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( const float* QuantBScaleBegin, bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& PackedQuantB, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -353,20 +446,49 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); } +static void +SQ8BitGemmPackQuantBDataAndBlkSum512( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 128; + } + Q8PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum512; - d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; - d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index d79554c34c108..7ca72debd6d25 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -35,6 +35,13 @@ // bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); //} +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen128[32], 64) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; + static MLAS_FORCEINLINE void dot_accumulate_1blk( const __m512i& bv0_64_epi8, @@ -139,6 +146,117 @@ accumulate_blklen128_r2c1blk1_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen128_r1c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + + if constexpr (vnni) { + dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, scale_a0b, acc0); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(dot00_16_epi32, dot01_16_epi32); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_set1_ps(scale_a0b), acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen128_r2c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + float scale_a1b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + + if constexpr (vnni) { + dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, scale_a0b, acc0); + dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, scale_a1b, acc1); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(dot00_16_epi32, dot01_16_epi32); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_set1_ps(scale_a0b), acc0); + + // row 1 + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + const __m512i sum1_16_epi32 = _mm512_add_epi32(dot10_16_epi32, dot11_16_epi32); + + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, _mm512_set1_ps(scale_a1b), acc1); + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen128Avx512( @@ -251,6 +369,110 @@ Q4Int8GemmR2xC4BlkLen128Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + const float scale_a0b1 = (*QuantAScalePtr) * (*(QuantBScalePtr + 1)); + const float scale_a0b2 = (*QuantAScalePtr) * (*(QuantBScalePtr + 2)); + const float scale_a0b3 = (*QuantAScalePtr) * (*(QuantBScalePtr + 3)); + const float scale_a1b0 = (*(QuantAScalePtr + BlockCountK)) * (*QuantBScalePtr); + const float scale_a1b1 = (*(QuantAScalePtr + BlockCountK)) * (*(QuantBScalePtr + 1)); + const float scale_a1b2 = (*(QuantAScalePtr + BlockCountK)) * (*(QuantBScalePtr + 2)); + const float scale_a1b3 = (*(QuantAScalePtr + BlockCountK)) * (*(QuantBScalePtr + 3)); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, scale_a0b2, scale_a1b2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + *(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += *(BiasPtr + 3); + } + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen128Avx512( @@ -332,6 +554,89 @@ Q4Int8GemmR2xC1BlkLen128Avx512( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + const float scale_a1b0 = (*(QuantAScalePtr + BlockCountK)) * (*QuantBScalePtr); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen128Avx512( @@ -411,6 +716,90 @@ Q4Int8GemmR1xC4BlkLen128Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + const float scale_a0b1 = (*QuantAScalePtr) * (*(QuantBScalePtr + 1)); + const float scale_a0b2 = (*QuantAScalePtr) * (*(QuantBScalePtr + 2)); + const float scale_a0b3 = (*QuantAScalePtr) * (*(QuantBScalePtr + 3)); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, scale_a0b0, acc[0]); + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, acc[1]); + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, scale_a0b2, acc[2]); + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, scale_a0b3, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr +=NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen128Avx512( @@ -487,6 +876,82 @@ Q4Int8GemmR1xC1BlkLen128Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, scale_a0b0, acc0); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen128Avx512( @@ -579,3 +1044,97 @@ MlasQ4Int8GemmKernelBlkLen128Avx512( return CountM; } + +template +MLAS_FORCEINLINE size_t +MlasQ8Int8GemmKernelBlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index 03064886caf24..b720c45b637ad 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -9,7 +9,14 @@ #include "sqnbitgemm_kernel_avx512_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" - +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen16[48], 64) = { + 0x00000000, 0x00000000, 0x00000004, 0x00000004, 0x00000001, 0x00000001, 0x00000005, 0x00000005, + 0x00000002, 0x00000002, 0x00000006, 0x00000006, 0x00000003, 0x00000003, 0x00000007, 0x00000007, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; static MLAS_FORCEINLINE void load_4blk_4b_packed_blklen16(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) @@ -120,6 +127,131 @@ accumulate_blklen16_r2c1blk4_avx512( } } +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r2c1blk8_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16)); // 0044115522663377 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a0b_ps, 0); // 0123456700000000 + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); // 0x8 1x8 2x8 3x8 + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); // 4x8 5x8 6x8 7x8 + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 0000111122223333 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 4444555566667777 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + + // row 1 + const __m256 scale_a1_ps = _mm256_loadu_ps(scale_a1); + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a1b_ps, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); // 0x8 1x8 2x8 3x8 + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); // 4x8 5x8 6x8 7x8 + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); // 0000111122223333 + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); // 4444555566667777 + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + + const __m512i t11 = _mm512_unpacklo_epi64(dot10_16_epi32, dot11_16_epi32); // 0044115522663377 + const __m512i t12 = _mm512_unpackhi_epi64(dot10_16_epi32, dot11_16_epi32); // 0044115522663377 + const __m512i sum1_16_epi32 = _mm512_add_epi32(t11, t12); + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, scale_a1b_16_ps, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk8_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16)); // 0044115522663377 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a0b_ps, 0); // 0123456700000000 + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); // 0x8 1x8 2x8 3x8 + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); // 4x8 5x8 6x8 7x8 + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 0000111122223333 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 4444555566667777 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); +} + static MLAS_FORCEINLINE void accumulate_blklen16_r1c1blk8_avx512vnni( const __m512i& av0_64_epi8, @@ -214,6 +346,36 @@ accumulate_blklen16_r2c1blk4_avx512vnni( } } +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk8_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16)); // 0044115522663377 + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a0b_ps, 0); // 01234567 00000000 + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); +} + template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen16Avx512( @@ -399,6 +561,152 @@ Q4Int8GemmR2xC4BlkLen16Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData8 = BlkDataSizeInBytes * PerAccuBlk8; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[NCols4]); + + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + PerAccuBlk8, acc[1]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk8, acc[NCols4 + 1]); + + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk8, acc[2]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk8, acc[NCols4 + 2]); + + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk8, acc[3]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk8, acc[NCols4 + 3]); + } else { + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk8, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk8, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk8, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += StrideQuantBData8 * NCols4; + QuantBScalePtr += PerAccuBlk8 * NCols4; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + // In A, the bytes beyond K has set to 0. + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const __m128i av1_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_a0b0 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_a1b0 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr, scale_a0b0, acc2[0]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr, scale_a1b0, acc2[NCols4]); + + const float scale_a0b1 = scale_a00 * (QuantBScalePtr + 1)[0]; + const float scale_a1b1 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a0b1, acc2[1]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a1b1, acc2[NCols4 + 1]); + + const float scale_a0b2 = scale_a00 * (QuantBScalePtr + 2)[0]; + const float scale_a1b2 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a0b2, acc2[2]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a1b2, acc2[NCols4 + 2]); + + const float scale_a0b3 = scale_a00 * (QuantBScalePtr + 3)[0]; + const float scale_a1b3 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a0b3, acc2[3]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a1b3, acc2[NCols4 + 3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2C1BlkLen16Avx512( @@ -509,6 +817,108 @@ Q4Int8GemmR2C1BlkLen16Avx512( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc1); + } else { + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const __m128i av1_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr, scale_00, acc20); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr, scale_10, acc21); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen16Avx512( @@ -628,6 +1038,118 @@ Q4Int8GemmR1xC4BlkLen16Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData8 = PerAccuBlk8 * BlkDataSizeInBytes; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + PerAccuBlk8, acc[1]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk8, acc[2]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk8, acc[3]); + } else { + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + PerAccuBlk8, acc[1]); + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk8, acc[2]); + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk8, acc[3]); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += StrideQuantBData8 * NCols4; + QuantBScalePtr += PerAccuBlk8 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const float scale_a00 = *QuantAScalePtr; + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, acc2[1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, acc2[2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, acc2[3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen16Avx512( @@ -719,6 +1241,94 @@ Q4Int8GemmR1xC1BlkLen16Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2 = h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t @@ -810,3 +1420,94 @@ MlasQ4Int8GemmKernelBlkLen16Avx512( return CountM; } + +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index 3b1096ac05ba7..f630883de92b4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -8,6 +8,15 @@ #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen32[48], 64) = { + 0x00000000, 0x00000000, 0x00000002, 0x00000002, 0x00000000, 0x00000000, 0x00000002, 0x00000002, + 0x00000001, 0x00000001, 0x00000003, 0x00000003, 0x00000001, 0x00000001, 0x00000003, 0x00000003, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; + static MLAS_FORCEINLINE void load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) { @@ -115,6 +124,139 @@ accumulate_blklen32_r2c1blk4_avx512( } } +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32)); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a0b_ps, 0); + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 00220022 11331133 + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 00000000 11111111 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 22222222 33333333 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 00220022 11331133 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); +} + +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32)); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a0b_ps, 0); + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 00220022 11331133 + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 00000000 11111111 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 22222222 33333333 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 00220022 11331133 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + + // row 1 + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a1b_ps, 0); + + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 00220022 11331133 + + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); // 00000000 11111111 + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); // 22222222 33333333 + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + + const __m512i t11 = _mm512_unpacklo_epi64(dot10_16_epi32, dot11_16_epi32); // 00220022 11331133 + const __m512i t12 = _mm512_unpackhi_epi64(dot10_16_epi32, dot11_16_epi32); + const __m512i sum1_16_epi32 = _mm512_add_epi32(t11, t12); + + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, scale_a1b_16_ps, acc1); +} + static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk4_avx512vnni( const __m512i& av0_64_epi8, @@ -203,6 +345,38 @@ accumulate_blklen32_r2c1blk4_avx512vnni( } } +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32)); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a0b_ps, 0); + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); +} + MLAS_FORCEINLINE void accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) { @@ -256,6 +430,44 @@ accumulate_blklen32_r2c1blk1_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk1_avx512( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + __m256& acc0 +) +{ + if constexpr (vnni) { + const __m256i bv_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { + accumulate_q8_blklen32_r1c1blk1_avx2(av00_32_epi8, QuantBDataPtr, combined_scale00, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk1_avx512( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + float combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + if constexpr (vnni) { + const __m256i bv_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_avx512vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { + accumulate_q8_blklen32_r2c1blk1_avx2(av00_32_epi8, av10_32_epi8, QuantBDataPtr, combined_scale00, combined_scale10, acc0, acc1); + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen32Avx512( @@ -437,6 +649,142 @@ Q4Int8GemmR2xC4BlkLen32Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = PerAccuBlk4 * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[NCols4]); + + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, acc[NCols4 + 1]); + + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, acc[NCols4 + 2]); + + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, acc[NCols4 + 3]); + } else { + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + for (; k_blks_remaining > 0; --k_blks_remaining) { + // load A + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0], scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0], scale_11 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, scale_11, acc2[1], acc2[NCols4 + 1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0], scale_12 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, scale_12, acc2[2], acc2[NCols4 + 2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0], scale_13 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, scale_13, acc2[3], acc2[NCols4 + 3]); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2C1BlkLen32Avx512( @@ -548,8 +896,8 @@ Q4Int8GemmR2C1BlkLen32Avx512( } template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC4BlkLen32Avx512( +void MLAS_FORCEINLINE +Q8Int8GemmR2C1BlkLen32Avx512( const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, @@ -559,69 +907,170 @@ Q4Int8GemmR1xC4BlkLen32Avx512( size_t CountN, size_t BlockCountK, const float* Bias, - size_t ldc -) + size_t ldc) { constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); constexpr size_t PerAccuBlk4 = 4; const size_t lda = BlockCountK * BlkLen32; - //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - //const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; - assert(CountM < NRows2); - assert(CountN % NCols4 == 0); + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); - for (size_t m = 0; m < CountM; m++) { + for (size_t m = 0; m < CountM; m += NRows2) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; + float* SumPtr = C + m * ldc; - for (size_t n = 0; n < CountN; n += NCols4) { + for (size_t n = 0; n < CountN; n++) { const std::byte* QuantAPtr = QuantA + m * lda; const float* QuantAScalePtr = QuantAScale + m * BlockCountK; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; - __m512 acc[NCols4] = { - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() - }; + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); if constexpr (vnni) { - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc1); } else { - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); } + // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk4; QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; - QuantBScalePtr += PerAccuBlk4 * NCols4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; } - __m256 acc2[NCols4] = { - h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) - }; - - while (k_blks_remaining-- > 0) { + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } else { + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + while (k_blks_remaining-- > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); @@ -667,6 +1116,114 @@ Q4Int8GemmR1xC4BlkLen32Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } else { + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float scale_a00 = *QuantAScalePtr; + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_01, acc2[1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_02, acc2[2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_03, acc2[3]); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen32Avx512( @@ -759,6 +1316,94 @@ Q4Int8GemmR1xC1BlkLen32Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc2 = h_add_512(acc0); + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + + const float scale_a00 = *QuantAScalePtr; + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t @@ -850,3 +1495,93 @@ MlasQ4Int8GemmKernelBlkLen32Avx512( return CountM; } + +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2C1BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 72ce28d834199..33d4fde26ae5b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -6,6 +6,13 @@ #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen64[32], 64) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; + static MLAS_FORCEINLINE __m256 h_add_512(__m512 a) { @@ -125,7 +132,7 @@ dot_accumulate_2blkvnni( __m512i t1_16_epi32 = _mm512_unpacklo_epi32(dot0_16_epi32, dot1_16_epi32); __m512i t2_16_epi32 = _mm512_unpackhi_epi32(dot0_16_epi32, dot1_16_epi32); - __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 0 1 1 0 0 1 1... + __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 1 0 1 0 1 0 1... __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); @@ -182,6 +189,146 @@ accumulate_blklen64_r2c1blk2_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r1c1blk2_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni(av00_64_epi8, av01_64_epi8, scale_a0, bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, acc0); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + const __m512 scale_a0_16_ps = _mm512_broadcast_f32x8(scale_a0_ps); + const __m512 scale_a0b_16_ps = _mm512_mul_ps(scale_b_16_ps, scale_a0_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi32(dot00_16_epi32, dot01_16_epi32); // 01010101 01010101 + const __m512i t02 = _mm512_unpackhi_epi32(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r2c1blk2_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni(av00_64_epi8, av01_64_epi8, scale_a0, bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, acc0); + dot_accumulate_2blkvnni(av10_64_epi8, av11_64_epi8, scale_a1, bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, acc1); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + const __m512 scale_a0_16_ps = _mm512_broadcast_f32x8(scale_a0_ps); + const __m512 scale_a0b_16_ps = _mm512_mul_ps(scale_b_16_ps, scale_a0_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi32(dot00_16_epi32, dot01_16_epi32); // 01010101 01010101 + const __m512i t02 = _mm512_unpackhi_epi32(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + + // row 1 + const __m256 scale_a1_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + const __m512 scale_a1_16_ps = _mm512_broadcast_f32x8(scale_a1_ps); + const __m512 scale_a1b_16_ps = _mm512_mul_ps(scale_b_16_ps, scale_a1_16_ps); + + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + + const __m512i t11 = _mm512_unpacklo_epi32(dot10_16_epi32, dot11_16_epi32); + const __m512i t12 = _mm512_unpackhi_epi32(dot10_16_epi32, dot11_16_epi32); + const __m512i sum1_16_epi32 = _mm512_add_epi32(t11, t12); + + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, scale_a1b_16_ps, acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk2_avx512( @@ -283,6 +430,112 @@ accumulate_blklen64_r2c1blk1_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r1c1blk1_avx512( + const __m512i& av0_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } else { + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + __m512i bv_low_64_epi8 = _mm512_and_si512(bv_64_epi8, low_mask); + __m512i bv_high_64_epi8 = _mm512_and_si512(bv_64_epi8, high_mask); + + // row 0 + __m512i dot0_low_32_epi16 = _mm512_maddubs_epi16(bv_low_64_epi8, av0_64_epi8); + __m512i dot0_high_32_epi16 = _mm512_maddubs_epi16(bv_high_64_epi8, av0_64_epi8); + __m512i dot0_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_low_32_epi16); + __m512i dot0_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_high_32_epi16); + __m512i dot0_16_epi32 = _mm512_add_epi32(dot0_low_16_epi32, dot0_high_16_epi32); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r2c1blk1_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av1_64_epi8); + __m512 sum1_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + acc1 = _mm512_fmadd_ps(sum1_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } else { + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + __m512i bv_low_64_epi8 = _mm512_and_si512(bv_64_epi8, low_mask); + __m512i bv_high_64_epi8 = _mm512_and_si512(bv_64_epi8, high_mask); + + // row 0 + __m512i dot0_low_32_epi16 = _mm512_maddubs_epi16(bv_low_64_epi8, av0_64_epi8); + __m512i dot0_high_32_epi16 = _mm512_maddubs_epi16(bv_high_64_epi8, av0_64_epi8); + __m512i dot0_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_low_32_epi16); + __m512i dot0_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_high_32_epi16); + __m512i dot0_16_epi32 = _mm512_add_epi32(dot0_low_16_epi32, dot0_high_16_epi32); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + + // row 1 + __m512i dot1_low_32_epi16 = _mm512_maddubs_epi16(bv_low_64_epi8, av1_64_epi8); + __m512i dot1_high_32_epi16 = _mm512_maddubs_epi16(bv_high_64_epi8, av1_64_epi8); + __m512i dot1_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot1_low_32_epi16); + __m512i dot1_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot1_high_32_epi16); + __m512i dot1_16_epi32 = _mm512_add_epi32(dot1_low_16_epi32, dot1_high_16_epi32); + __m512 sum1_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); + + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum1_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk1_avx512( @@ -448,6 +701,106 @@ Q4Int8GemmR2xC4BlkLen64Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen64Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen64); + + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen64Avx512( @@ -540,6 +893,95 @@ Q4Int8GemmR2xC1BlkLen64Avx512( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen64Avx512( @@ -633,6 +1075,94 @@ Q4Int8GemmR1xC4BlkLen64Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += PerAccuBlk2 * BlkDataSizeInBytes * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen64Avx512( @@ -718,6 +1248,88 @@ Q4Int8GemmR1xC1BlkLen64Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_q8_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + + accumulate_q8_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen64Avx512( @@ -838,3 +1450,96 @@ MlasQ4Int8GemmKernelBlkLen64Avx512( return CountM; } + +template +MLAS_FORCEINLINE size_t +MlasQ8Int8GemmKernelBlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen64Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index c2fcd92be2364..ea5eebd854655 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -299,6 +299,99 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni( return CountM; } +MLAS_FORCEINLINE +size_t +SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen == 16) { + MlasQ8Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ8Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ8Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ8Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + void MLASCALL QuantizeARow_CompInt8_avx512( size_t BlkLen, @@ -319,7 +412,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( const float* QuantBScaleBegin, bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& PackedQuantB, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -335,23 +428,52 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); } +static void +SQ8BitGemmPackQuantBDataAndBlkSum512vnni( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 128; + } + Q8PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + // // Kernel dispatch structure definition. // const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum512vnni; - d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; - d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 02429a0c64f8e..e7df817dea34c 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -6,8 +6,9 @@ // Quantized B data packing function implementation. // +template static size_t -Q4BitGemmPackQuantBDataSize( +QNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, @@ -15,7 +16,6 @@ Q4BitGemmPackQuantBDataSize( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - constexpr size_t BlkBitWidth = 4; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); if (ComputeType == SQNBIT_CompInt8) { size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); @@ -247,6 +247,60 @@ PackQuantB( ); } +static void +Q8PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen) +{ + constexpr size_t BlkBitWidth = 8; + const size_t StrideN = BlockCountK * BlkLen; + const size_t BlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t SubBlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen); + const size_t SubBlkCountK = MlasDivRoundup(StrideN, SubBlkLen); + const size_t RemainderBlockCountK = BlockCountK % (SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + // SubBlkLen rows x 4 columns pack together, then remainder BlkLen x 4 columns if SubBlkLen > BlkLen. + // remainder columns keep the original order. + // SubBlkLen >= 16 and is multiple of 16 + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t c = tid / SubBlkCountK; + const size_t c_4 = c & (~3), c_res = c & 3; + const size_t r_subblk = tid % SubBlkCountK; + + const std::byte* src = QuantBDataBegin + c * StrideN + r_subblk * SubBlkLen; + + if (c_4 + 4 <= N) { // full 4 cols + if (RemainderBlockCountK && r_subblk == SubBlkCountK - 1) { // remainder blocks + std::byte* dest = + PackedQuantBDataBegin + c_4 * StrideN + r_subblk * SubBlkSize * 4 + c_res * BlkSize; + for (size_t i = 0; i < RemainderBlockCountK; i++) { + std::copy(src, src + BlkSize, dest); + src += BlkSize; + dest += BlkSize * 4; + } + } else { // full subblock + std::byte* dest = + PackedQuantBDataBegin + c_4 * StrideN + r_subblk * SubBlkSize * 4 + c_res * SubBlkSize; + std::copy(src, src + SubBlkSize, dest); + } + } else { // remainder cols + std::byte* dest = + PackedQuantBDataBegin + c * StrideN + r_subblk * SubBlkSize; + std::copy(src, src + std::min(SubBlkSize, StrideN - r_subblk * SubBlkSize), dest); + } + } + ); +} + //#include static void @@ -295,6 +349,61 @@ ComputePackBlkSum( ); } +static void +Q8ComputePackBlkSum( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t BlockCountK) +{ + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t n_4 = n & (~3), n_res = n & 3; + const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 128; + if (QuantBZPBegin) { + const std::byte* QuantBZP = QuantBZPBegin + src_blk_offset; + zp = (uint8_t)(*QuantBZP); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + + // re-arrange scale to the same order as packed data + if (n_4 + 4 > N) { + *(QuantBScaleBegin + n * BlockCountK + k_blk) = QuantBScale; + } else if (BlkLen >= SubBlkLen) { + *(QuantBScaleBegin + n_4 * BlockCountK + k_blk * 4 + n_res) = QuantBScale; + } else { + size_t blks_per_sub = SubBlkLen / BlkLen; + size_t remainder_blk = BlockCountK % blks_per_sub; + size_t sub_blk_count_k = MlasDivRoundup(BlockCountK, blks_per_sub); + size_t k_subblk = k_blk / blks_per_sub; + size_t k_blk_res = k_blk % blks_per_sub; + size_t dest_offset; + + if (remainder_blk && k_subblk == sub_blk_count_k - 1) { // remainder blocks + dest_offset = n_4 * BlockCountK + k_blk * 4 + n_res; + } else { // full subblock + dest_offset = n_4 * BlockCountK + k_subblk * blks_per_sub * 4 + n_res * blks_per_sub + k_blk_res; + } + + *(QuantBScaleBegin + dest_offset) = QuantBScale; + } + }); +} + static void PackQuantBDataAndBlkSum( size_t N, @@ -305,7 +414,7 @@ PackQuantBDataAndBlkSum( const float* QuantBScaleBegin, bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& PackedQuantB, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -322,12 +431,39 @@ PackQuantBDataAndBlkSum( } } +static void +Q8PackQuantBDataAndBlkSum( + size_t N, + size_t BlockCountK, + size_t BlkLen, + size_t SubBlkLen, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + if (QuantBDataBegin) { + Q8PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + Q8ComputePackBlkSum(BlkLen, SubBlkLen, N, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, ThreadPool, BlockCountK); + } +} + // // Workspace size calculation function implementation. // static size_t -Q4BitGemmPerGemmWorkspaceSize( +QNBitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, @@ -353,7 +489,7 @@ Q4BitGemmPerGemmWorkspaceSize( } static size_t -Q4BitGemmPerGemmWorkspaceAlignment( +QNBitGemmPerGemmWorkspaceAlignment( size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) diff --git a/onnxruntime/core/mlas/lib/transpose.cpp b/onnxruntime/core/mlas/lib/transpose.cpp index 1ee2f90357e9e..61c379668a0a2 100644 --- a/onnxruntime/core/mlas/lib/transpose.cpp +++ b/onnxruntime/core/mlas/lib/transpose.cpp @@ -484,20 +484,20 @@ MlasTranspose8x8Block( __m128i c3 = __lsx_vilvh_h(b3, b2); __m128 d0 = (__m128)(__lsx_vilvl_w(c2, c0)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0), __lsx_vpickve2gr_d(d0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(d0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); + __lsx_vstelm_d(d0, &Output[OutputStride * 0], 0, 0); + __lsx_vstelm_d(d0, &Output[OutputStride * 1], 0, 1); __m128 d1 = (__m128)(__lsx_vilvh_w(c2, c0)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(d1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(d1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); + __lsx_vstelm_d(d1, &Output[OutputStride * 2], 0, 0); + __lsx_vstelm_d(d1, &Output[OutputStride * 3], 0, 1); __m128 d2 = (__m128)(__lsx_vilvl_w(c3, c1)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 4], 0), __lsx_vpickve2gr_d(d2, 0), 0), (__m128i *)&Output[OutputStride * 4], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 5], 0), __lsx_vpickve2gr_d(d2, 1), 0), (__m128i *)&Output[OutputStride * 5], 0); + __lsx_vstelm_d(d2, &Output[OutputStride * 4], 0, 0); + __lsx_vstelm_d(d2, &Output[OutputStride * 5], 0, 1); __m128 d3 = (__m128)(__lsx_vilvh_w(c3, c1)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 6], 0), __lsx_vpickve2gr_d(d3, 0), 0), (__m128i *)&Output[OutputStride * 6], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 7], 0), __lsx_vpickve2gr_d(d3, 1), 0), (__m128i *)&Output[OutputStride * 7], 0); + __lsx_vstelm_d(d3, &Output[OutputStride * 6], 0, 0); + __lsx_vstelm_d(d3, &Output[OutputStride * 7], 0, 1); } #endif diff --git a/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx2.S index 3004599bcb3d4..194f210a4aa91 100644 --- a/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx2.S +++ b/onnxruntime/core/mlas/lib/x86_64/ConvSymKernelAvx2.S @@ -23,6 +23,91 @@ Abstract: .intel_syntax noprefix + .extern CheckSaturationForVPMADDUBSW + + .macro CheckSaturation VecReg1Num, VecReg2Num + +// +// Save all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11) +// + + push rax + push rcx + push rdx + push rsi + push rdi + push r8 + push r9 + push r10 + push r11 + + sub rsp, 512 # reserve space for 16 YMM registers (32 bytes) + +// +// Save YMM registers (YMM0 to YMM15) +// + + vmovdqu [rsp], ymm0 + vmovdqu [rsp+32], ymm1 + vmovdqu [rsp+64], ymm2 + vmovdqu [rsp+96], ymm3 + vmovdqu [rsp+128], ymm4 + vmovdqu [rsp+160], ymm5 + vmovdqu [rsp+192], ymm6 + vmovdqu [rsp+224], ymm7 + vmovdqu [rsp+256], ymm8 + vmovdqu [rsp+288], ymm9 + vmovdqu [rsp+320], ymm10 + vmovdqu [rsp+352], ymm11 + vmovdqu [rsp+384], ymm12 + vmovdqu [rsp+416], ymm13 + vmovdqu [rsp+448], ymm14 + vmovdqu [rsp+480], ymm15 + + lea rdi, [rsp+32*\VecReg1Num\()] # first operand (unsigned) + lea rsi, [rsp+32*\VecReg2Num\()] # second operand (signed) + + call CheckSaturationForVPMADDUBSW + +// +// Restore YMM registers +// + + vmovdqu ymm0, [rsp] + vmovdqu ymm1, [rsp+32] + vmovdqu ymm2, [rsp+64] + vmovdqu ymm3, [rsp+96] + vmovdqu ymm4, [rsp+128] + vmovdqu ymm5, [rsp+160] + vmovdqu ymm6, [rsp+192] + vmovdqu ymm7, [rsp+224] + vmovdqu ymm8, [rsp+256] + vmovdqu ymm9, [rsp+288] + vmovdqu ymm10, [rsp+320] + vmovdqu ymm11, [rsp+352] + vmovdqu ymm12, [rsp+384] + vmovdqu ymm13, [rsp+416] + vmovdqu ymm14, [rsp+448] + vmovdqu ymm15, [rsp+480] + + add rsp, 512 # clean up the reserved stack space + +// +// Restore all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11) +// + + pop r11 + pop r10 + pop r9 + pop r8 + pop rdi + pop rsi + pop rdx + pop rcx + pop rax + + .endm + /*++ Macro Description: @@ -52,9 +137,15 @@ Implicit Arguments: .macro MultiplyAccumulateRowAvx2 Vec1Reg, Vec2Reg +#if defined(ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + CheckSaturation 2,0 +#endif vpmaddubsw ymm3,ymm2,ymm0 vpmaddwd ymm3,ymm3,ymm12 vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 +#if defined(ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + CheckSaturation 2,1 +#endif vpmaddubsw ymm2,ymm2,ymm1 vpmaddwd ymm2,ymm2,ymm12 vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 diff --git a/onnxruntime/core/optimizer/bias_softmax_fusion.cc b/onnxruntime/core/optimizer/bias_softmax_fusion.cc old mode 100755 new mode 100644 diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index 71c8667a89b1d..04f74eb860443 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -19,24 +19,12 @@ namespace { #if !defined(ORT_MINIMAL_BUILD) namespace selectors { + const Node* GetLoneConsumerNode(const GraphViewer& graph_viewer, const Node& node) { if (!optimizer_utils::CheckOutputEdges(graph_viewer.GetGraph(), node, 1)) { return nullptr; } - const Node* next_node = &*node.OutputNodesBegin(); - // ensure that the target node also has only one input that is not an initializer - const size_t input_edges_total = next_node->GetInputEdgesCount(); - int non_const_edges = 0; - for (size_t edge_idx = 0; edge_idx < input_edges_total; ++edge_idx) { - if (!graph_utils::NodeArgIsConstant(graph_viewer.GetGraph(), *next_node->InputDefs()[edge_idx])) { - ++non_const_edges; - } - } - if (non_const_edges > 1) { - return nullptr; - } else { - return next_node; - } + return &*node.OutputNodesBegin(); } bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) { @@ -143,6 +131,7 @@ class ConvActivationSelector : public NodeSelector { #endif // !defined(ORT_MINIMAL_BUILD) namespace actions { + using NTO = NodesToOptimize; class FuseConvActivationAction : public ReplaceWithNew { @@ -217,36 +206,6 @@ class FuseConvActivationAction : public ReplaceWithNew { } }; -class FuseConvAddRelu : public ReplaceWithNew { - private: - std::string OpType(const RuntimeState&) const override { return "FusedConv"; } - - std::string Domain(const RuntimeState&) const override { return kMSDomain; } - - NodeAttributes ExtraAttributes(const RuntimeState&) const override { - NodeAttributes extra_fused_conv_attributes; - utils::SetNodeAttribute(utils::MakeAttribute("activation", "Relu"), extra_fused_conv_attributes); - return extra_fused_conv_attributes; - } - - std::vector ValueMoves(const RuntimeState& state) const override { - const auto& conv = state.selected_nodes.Target(); - - ORT_ENFORCE(conv.GetOutputEdgesCount() == 1 && conv.OutputNodesBegin()->OpType() == "Add", - "Expected Conv then Add."); - const auto add_input_idx = 1 - conv.OutputEdgesBegin()->GetDstArgIndex(); - - const auto conv_location = NTO::NodeLocation{NTO::NodeType::kTarget, 0}; - const auto add_location = NTO::NodeLocation{NTO::NodeType::kOutput, 0}; - const auto relu_location = NTO::NodeLocation{NTO::NodeType::kOutput, 1}; - - return { - MoveAll(conv_location, ArgType::kInput), // move all inputs from conv - MoveAndAppend(add_location, ArgType::kInput, add_input_idx, ArgType::kInput), // append add input - MoveAll(relu_location, ArgType::kOutput), // move all outputs from relu - }; - } -}; } // namespace actions void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) { diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 2754eebf75421..c2cdf360ad986 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -90,7 +90,10 @@ bool ConvertNodeLayout(const api::NodeRef& node) { } #endif -#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS +// TODO: We don't need to check USE_CUDA || USE_CUDA_PROVIDER_INTERFACE in this function because we're already +// checking if the node is assigned to the desired EP (e.g., CUDA EP). We should only need to check +// ENABLE_CUDA_NHWC_OPS. +#if (defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)) && ENABLE_CUDA_NHWC_OPS if (node.GetExecutionProviderType() == kCudaExecutionProvider) { if (layout_sensitive_ops.count(node.OpType())) { const auto& cuda_nhwc_ops = GetCUDALayoutSensitiveOps(); @@ -101,6 +104,18 @@ bool ConvertNodeLayout(const api::NodeRef& node) { } #endif +// TODO: We don't really need EP pre-processor macros in this function because we're already checking if the +// node is assigned to the desired EP (e.g., QNN EP). There's nothing about this code that absolutely requires +// conditional compilation. +#if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE) + if (node.GetExecutionProviderType() == kQnnExecutionProvider) { + if (node.OpType() == "Upsample") { + // Upsample is translated to QNN's Resize, which requires the NHWC layout for processing. + return true; + } + } +#endif + return layout_sensitive_ops.count(node.OpType()) != 0; } } // namespace diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h index 6e627ecc0d7e1..afd045fa4d0db 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h @@ -24,6 +24,7 @@ inline constexpr std::array kLayoutTransformationPotentiallyAddedOps = { OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 13}, OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 19}, OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 21}, + OpIdentifierWithStringViews{kOnnxDomain, "DequantizeLinear", 23}, OpIdentifierWithStringViews{kOnnxDomain, "Gather", 1}, OpIdentifierWithStringViews{kOnnxDomain, "Gather", 11}, OpIdentifierWithStringViews{kOnnxDomain, "Gather", 13}, @@ -33,21 +34,26 @@ inline constexpr std::array kLayoutTransformationPotentiallyAddedOps = { OpIdentifierWithStringViews{kOnnxDomain, "Identity", 16}, OpIdentifierWithStringViews{kOnnxDomain, "Identity", 19}, OpIdentifierWithStringViews{kOnnxDomain, "Identity", 21}, + OpIdentifierWithStringViews{kOnnxDomain, "Identity", 23}, OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 10}, OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 13}, OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 19}, OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 21}, + OpIdentifierWithStringViews{kOnnxDomain, "QuantizeLinear", 23}, OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 1}, OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 11}, OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 13}, OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 21}, + OpIdentifierWithStringViews{kOnnxDomain, "Squeeze", 23}, OpIdentifierWithStringViews{kOnnxDomain, "Transpose", 1}, OpIdentifierWithStringViews{kOnnxDomain, "Transpose", 13}, OpIdentifierWithStringViews{kOnnxDomain, "Transpose", 21}, + OpIdentifierWithStringViews{kOnnxDomain, "Transpose", 23}, OpIdentifierWithStringViews{kOnnxDomain, "Unsqueeze", 1}, OpIdentifierWithStringViews{kOnnxDomain, "Unsqueeze", 11}, OpIdentifierWithStringViews{kOnnxDomain, "Unsqueeze", 13}, OpIdentifierWithStringViews{kOnnxDomain, "Unsqueeze", 21}, + OpIdentifierWithStringViews{kOnnxDomain, "Unsqueeze", 23}, #if !defined(DISABLE_CONTRIB_OPS) // kMSDomain ops diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index 338722fb00782..4b9259c080da3 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -26,7 +26,22 @@ struct ExtractScalarAsFloatDispatchTarget { } }; -std::optional GetScalarConstantInitializer(const Graph& graph, const NodeArg& node_arg) { +std::optional GetTensorShape(const NodeArg& node_arg) { + const auto* shape_proto = node_arg.Shape(); + if (!shape_proto) { + return std::nullopt; + } + + return utils::GetTensorShapeFromTensorShapeProto(*shape_proto); +} + +// Note: In this context, we consider a scalar to be a single element tensor with rank up to `max_rank`. +// The stricter definition of a scalar having an empty shape would work too, but we can accept a bit more than that. +// In the case where `node_arg` has a non-empty shape, i.e., any dimensions with length 1, we only consider it a scalar +// if it does not have any broadcasting effect on the Mul or Div output shape. +// Because the dimension lengths can only be 1, we only need to consider additional leading dimensions being prepended. +// `max_rank` should be set to the rank of the other Mul or Div input to avoid that. +std::optional GetScalarConstantInitializer(const Graph& graph, const NodeArg& node_arg, size_t max_rank) { const auto* initializer = graph_utils::GetConstantInitializer(graph, node_arg.Name()); if (!initializer) { @@ -34,12 +49,12 @@ std::optional GetScalarConstantInitializer(const Graph& graph, const Node return {}; } - const auto* shape = node_arg.Shape(); + const auto shape = GetTensorShape(node_arg); ORT_ENFORCE( - shape, + shape.has_value(), "Constant initializer NodeArg shape should not be null. NodeArg: ", node_arg.Name()); - if (utils::GetTensorShapeFromTensorShapeProto(*shape).Size() != 1) { + if (shape->Size() != 1 || shape->NumDimensions() > max_rank) { // not a scalar return {}; } @@ -73,7 +88,10 @@ std::optional> GetScaleFromNode( if (is_excluded_initializer(scale_reciprocal_node_arg)) return std::nullopt; - const auto divisor = GetScalarConstantInitializer(graph, scale_reciprocal_node_arg); + const NodeArg& other_node_arg = *div_inputs[1 - scale_reciprocal_arg_index]; + const auto max_rank = GetTensorShape(other_node_arg).value_or(TensorShape{}).NumDimensions(); + + const auto divisor = GetScalarConstantInitializer(graph, scale_reciprocal_node_arg, max_rank); if (!divisor.has_value()) return std::nullopt; @@ -90,7 +108,10 @@ std::optional> GetScaleFromNode( if (is_excluded_initializer(scale_node_arg)) continue; - const auto multiplier = GetScalarConstantInitializer(graph, scale_node_arg); + const NodeArg& other_node_arg = *mul_inputs[1 - scale_arg_index]; + const auto max_rank = GetTensorShape(other_node_arg).value_or(TensorShape{}).NumDimensions(); + + const auto multiplier = GetScalarConstantInitializer(graph, scale_node_arg, max_rank); if (!multiplier.has_value()) continue; diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index fab5078921c7a..fe5874d067b95 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -204,12 +204,12 @@ bool IsQOrDQScalePositiveConstantScalar( #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) bool MatchQNode(const Node& node) { - return graph_utils::IsSupportedOptypeVersionAndDomain(node, QOpName, {10, 13, 19, 21}) || + return graph_utils::IsSupportedOptypeVersionAndDomain(node, QOpName, {10, 13, 19, 21, 23}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, QOpName, {1}, kMSDomain); } bool MatchDQNode(const Node& node) { - return graph_utils::IsSupportedOptypeVersionAndDomain(node, DQOpName, {10, 13, 19, 21}) || + return graph_utils::IsSupportedOptypeVersionAndDomain(node, DQOpName, {10, 13, 19, 21, 23}) || graph_utils::IsSupportedOptypeVersionAndDomain(node, DQOpName, {1}, kMSDomain); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc index a451e3ad60e94..83c5d7bc8d92a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc @@ -43,6 +43,14 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph continue; } + // Require that the node's output is consumed by a single QuantizeLinear node. + // Otherwise, if only the inputs are quantized, but not the output, then this node group would not + // be considered a QDQ node unit anyway. + std::vector children_nodes = graph.GetConsumerNodes(node.OutputDefs()[0]->Name()); + if (children_nodes.size() != 1 || children_nodes[0]->OpType() != QDQ::QOpName) { + continue; + } + Node& dq_0 = *graph.GetNode(parent_node_0->Index()); Node* dq_1 = nullptr; const ONNX_NAMESPACE::TensorProto* weight_proto = nullptr; diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index e4d59ea732d1e..ae321cca09e82 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -10,6 +10,7 @@ #include #include #include +#include #include namespace onnx_transpose_optimization { @@ -465,7 +466,7 @@ class GraphRef { } // namespace api constexpr int64_t kMinSupportedOpset = 7; -constexpr int64_t kMaxSupportedOpset = 22; +constexpr int64_t kMaxSupportedOpset = 23; // enum of results that a CostCheckFn can return. enum class CostCheckResult { diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index c7e11de34858a..4ea54f4db9e87 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -369,7 +369,27 @@ bool CheckOutputEdges(const Graph& graph, const Node& node, size_t expected_outp return false; } - return node.GetOutputEdgesCount() == expected_output_edges; + if (node.GetOutputEdgesCount() != expected_output_edges) { + return false; + } + + // Verify no output edges go to implicit inputs. + // An output edge to an implicit input implies the possibility of consumers in a subgraph. + // It is non-trivial to determine the actual number of corresponding edges in the subgraph. + // We also don't want to fuse part of a subgraph. This function is likely used from graph transformers to check if + // nodes can be fused. + // We'll just disallow output edges to implicit inputs for simplicity. + for (auto output_edge_it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); + output_edge_it != end; ++output_edge_it) { + const auto& output_node = output_edge_it->GetNode(); + const auto output_node_input_arg_idx = static_cast(output_edge_it->GetDstArgIndex()); + const bool is_implicit_input_to_output_node = output_node_input_arg_idx >= output_node.InputDefs().size(); + if (is_implicit_input_to_output_node) { + return false; + } + } + + return true; } bool IsScalar(const NodeArg& input_arg) { diff --git a/onnxruntime/core/optimizer/utils.h b/onnxruntime/core/optimizer/utils.h index b0da4becb0146..857640f861238 100644 --- a/onnxruntime/core/optimizer/utils.h +++ b/onnxruntime/core/optimizer/utils.h @@ -163,10 +163,12 @@ bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntim */ bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& min, float& max); -/** Check whether node's output edges count is expected. -@remarks graph output is not included in output edges, and this node shall not have graph output. - A node with graph output cannot be fused unless the graph output also exists in outputs of fused node. -@returns false when the node has graph output, or number of output edges are not expected. +/** Check whether `node` has expected outputs. +@remarks Graph outputs are not included in output edges, and this node should not have a graph output. + A node with a graph output cannot be fused unless the graph output also exists in outputs of the fused node. + Output edges to implicit inputs are also disallowed as we don't want to fuse with part of a subgraph. +@returns False when `node` has a graph output or an output edge to an implicit input, or when the number of `node`'s + output edges is not `expected_output_edges`. */ bool CheckOutputEdges(const Graph& graph, const Node& node, size_t expected_output_edges); diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc index f3520b4f7f7f5..fed61854860f0 100644 --- a/onnxruntime/core/platform/windows/debug_alloc.cc +++ b/onnxruntime/core/platform/windows/debug_alloc.cc @@ -75,45 +75,41 @@ struct SymbolHelper { SymbolHelper() = default; - static constexpr size_t kInitialBufferSize = sizeof(SYMBOL_INFO) + MAX_SYM_NAME; - - bool LoookupSymAndInitialize(const ULONG_PTR address, char* buffer, size_t buffer_size, SYMBOL_INFO* symbol) { - if (SymFromAddr(process_handle_, address, 0, symbol) != TRUE) { + bool LookupSymAndInitialize(const void* address, SYMBOL_INFO* symbol, std::ostream& message) { + if (SymFromAddr(process_handle_, reinterpret_cast(address), 0, symbol) != TRUE) { if (GetLastError() == ERROR_INVALID_HANDLE) { // Try to initialize first - if (!InitializeWhenNeeded() || SymFromAddr(process_handle_, address, 0, symbol) != TRUE) { - _snprintf_s(buffer, buffer_size, _TRUNCATE, "0x%08IX (Unknown symbol)", address); + if (!InitializeWhenNeeded() || + SymFromAddr(process_handle_, reinterpret_cast(address), 0, symbol) != TRUE) { + message << "0x" << address << " (Unknown symbol)"; return false; } } else { - _snprintf_s(buffer, buffer_size, _TRUNCATE, "0x%08IX (Unknown symbol)", address); + message << "0x" << address << " (Unknown symbol)"; return false; } } return true; } - void Lookup(std::string& string, const ULONG_PTR address) { - alignas(SYMBOL_INFO) char buffer[kInitialBufferSize] = {0}; - SYMBOL_INFO* symbol = reinterpret_cast(buffer); + void Lookup(const void* address, std::ostream& message) { + SYMBOL_INFO_PACKAGE symbol_info_package{}; + SYMBOL_INFO* symbol = &symbol_info_package.si; symbol->SizeOfStruct = sizeof(SYMBOL_INFO); - symbol->MaxNameLen = MAX_SYM_NAME; + symbol->MaxNameLen = std::size(symbol_info_package.name); - if (!LoookupSymAndInitialize(address, buffer, kInitialBufferSize, symbol)) { - string.append(buffer); + if (!LookupSymAndInitialize(address, symbol, message)) { return; } Line line; DWORD displacement; - if (SymGetLineFromAddr(process_handle_, address, &displacement, &line) == false) { - _snprintf_s(buffer, _TRUNCATE, "(unknown file & line number): %s", symbol->Name); - string.append(buffer); + if (SymGetLineFromAddr(process_handle_, reinterpret_cast(address), &displacement, &line) == false) { + message << "(unknown file & line number): " << symbol->Name; return; } - _snprintf_s(buffer, _TRUNCATE, "%s(%d): %s", line.FileName, static_cast(line.LineNumber), symbol->Name); - string.append(buffer); + message << line.FileName << "(" << line.LineNumber << "): " << symbol->Name; } struct Line : IMAGEHLP_LINE { @@ -221,17 +217,17 @@ Memory_LeakCheck::~Memory_LeakCheck() { const MemoryBlock& block = *static_cast(entry.lpData); const BYTE* pBlock = static_cast(entry.lpData) + sizeof(MemoryBlock); - std::string string; - char buffer[1024]; - _snprintf_s(buffer, _TRUNCATE, "%Iu bytes at location 0x%08IX\n", entry.cbData - sizeof(MemoryBlock), - UINT_PTR(pBlock)); - string.append(buffer); + std::ostringstream message; + message << (entry.cbData - sizeof(MemoryBlock)) << " bytes at location 0x" << static_cast(pBlock) + << "\n"; for (auto& p : block.m_pTraces) { if (!p) break; - symbols.Lookup(string, reinterpret_cast(p)); - string.push_back('\n'); + symbols.Lookup(p, message); + message << "\n"; } + const std::string string = message.str(); + // Google test has memory leaks that they haven't fixed. One such issue is tracked here: https://github.com/google/googletest/issues/692 // // In gtest-port.cc in function: static ThreadIdToThreadLocals* GetThreadLocalsMapLocked() @@ -271,12 +267,8 @@ Memory_LeakCheck::~Memory_LeakCheck() { if (leaked_bytes) { DebugPrint("-----Ending Heap Trace-----\n\n"); - std::string string; - char buffer[1024]; - _snprintf_s(buffer, _TRUNCATE, "%d bytes of memory leaked in %d allocations", static_cast(leaked_bytes), static_cast(leak_count)); - string.append(buffer); - - std::cout << "\n----- MEMORY LEAKS: " << string.c_str() << "\n"; + std::cout << "\n----- MEMORY LEAKS: " << leaked_bytes << " bytes of memory leaked in " + << leak_count << " allocations\n"; if (!IsDebuggerPresent()) { exit(-1); } diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.h b/onnxruntime/core/providers/acl/acl_execution_provider.h old mode 100755 new mode 100644 diff --git a/onnxruntime/core/providers/acl/acl_fwd.h b/onnxruntime/core/providers/acl/acl_fwd.h old mode 100755 new mode 100644 diff --git a/onnxruntime/core/providers/acl/acl_provider_factory.cc b/onnxruntime/core/providers/acl/acl_provider_factory.cc old mode 100755 new mode 100644 diff --git a/onnxruntime/core/providers/acl/math/gemm.cc b/onnxruntime/core/providers/acl/math/gemm.cc old mode 100755 new mode 100644 diff --git a/onnxruntime/core/providers/acl/nn/batch_norm.cc b/onnxruntime/core/providers/acl/nn/batch_norm.cc old mode 100755 new mode 100644 diff --git a/onnxruntime/core/providers/acl/nn/batch_norm.h b/onnxruntime/core/providers/acl/nn/batch_norm.h old mode 100755 new mode 100644 diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index dfa01c8187741..0cb56fbe4902b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -41,9 +41,7 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis)); AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", bool(keepdims))); - int32_t output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32; - // the output of ArgMax must be int32 - AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype); + AddOperationOutput(*op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(op)); } else { auto* coreml_argmax = layer->mutable_argmax(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 9e7fcd788664c..44d6845fc663c 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -115,8 +115,9 @@ bool BaseOpBuilder::IsInputDtypeSupport(const Node& node, size_t idx, } #if CAN_BUILD_COREML6_OR_LATER - // only MLProgram support FP16 - if (input_params.create_mlprogram && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + // only MLProgram support FP16 and INT64 + if (input_params.create_mlprogram && (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64)) { return true; } #endif diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index 153ae841b238f..1f54c894d1445 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -46,7 +46,7 @@ class BaseOpBuilder : public IOpBuilder { const logging::Logger& logger) const; virtual int GetMinSupportedOpSet(const Node& /*node*/) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /*node*/) const { return 21; } + virtual int GetMaxSupportedOpSet(const Node& /*node*/) const { return 23; } bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; bool HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index d7c78e05362ed..2a791cc71523f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -54,6 +54,17 @@ bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger y_shape_proto->dim().begin(), y_shape_proto->dim().end(), dim_eq); } + +bool ShouldUseFloorDiv(const Node& node, const logging::Logger& logger) { + // since ONNX spec requires both inputs to have the same type, we only need + // to check the first input type + const auto& input0 = *node.InputDefs()[0]; + int32_t input_type0 = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + GetType(input0, input_type0, logger); + + return input_type0 == ONNX_NAMESPACE::TensorProto_DataType_INT32 || + input_type0 == ONNX_NAMESPACE::TensorProto_DataType_INT64; +} } // namespace static std::vector InferOutputShape(const std::vector& a, const std::vector& b) { @@ -131,9 +142,13 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } else if (op_type == "Sub") { coreml_op_type = "sub"; } else if (op_type == "Div") { - // we support fp32/fp16 currently. when we add support for integers we need to check the type and use - // "floor_div" or "real_div" accordingly - coreml_op_type = "real_div"; + // Use "floor_div" op for integer division (int32 or int64) + // use "real_div" for float division (fp16 or fp32) + if (ShouldUseFloorDiv(node, logger)) { + coreml_op_type = "floor_div"; + } else { + coreml_op_type = "real_div"; + } } else if (op_type == "Pow") { coreml_op_type = "pow"; } else { diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index 684653aa21273..08098056f120a 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -261,9 +261,10 @@ MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type) { case ONNX_NAMESPACE::TensorProto_DataType_INT16: return MILSpec::DataType::INT16; case ONNX_NAMESPACE::TensorProto_DataType_INT32: - return MILSpec::DataType::INT32; case ONNX_NAMESPACE::TensorProto_DataType_INT64: - return MILSpec::DataType::INT64; + // CoreML only supports int32 for its operations and can only produce int32 values so + // we convert any int64 to int32. + return MILSpec::DataType::INT32; case ONNX_NAMESPACE::TensorProto_DataType_UINT8: return MILSpec::DataType::UINT8; @@ -367,8 +368,7 @@ void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, std::st SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(element_type), shape, /*convert_scalar*/ true); } -void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output, - std::optional override_element_type) { +void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output) { auto& outputs = *op.mutable_outputs(); auto& output_arg = *outputs.Add(); output_arg.set_name(output.Name()); @@ -376,10 +376,7 @@ void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& outp MILSpec::ValueType& value = *output_arg.mutable_type(); MILSpec::TensorType& tensor_type = *value.mutable_tensortype(); - auto elem_type = override_element_type ? *override_element_type - : output.TypeAsProto()->tensor_type().elem_type(); - - SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(elem_type), output.Shape(), /*convert_scalar*/ true); + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()), output.Shape(), /*convert_scalar*/ true); } void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type, diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index b72b66362b014..8f05e670acc77 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -98,6 +98,7 @@ COREML_SPEC::MILSpec::DataType DataTypeToMILSpec() { // The TensorProto.data_type field is an int, but must be a valid TensorProto_DataType value. // Use int for the arg so the caller can pass TensorProto.data_type() value and do the cast to enum internally +// This method also automatically converts int64 to int32 since only int32 is supported for CoreML operations. COREML_SPEC::MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type); /// @@ -156,12 +157,7 @@ void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, std::st /// /// Operation to update. /// NodeArg with details of output to add. -/// -/// Override the element type. Only set to handle cases where we believe the data at runtime will be int32 but -/// the original ONNX node has type int64. -/// -void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output, - std::optional override_element_type = std::nullopt); +void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output); /// /// Add pad_type and pad values. diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index 8abee92451338..e0665f5c2a5ec 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -44,7 +44,6 @@ Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model // CoreML operators can only produce int32 and not int64 values. // Due to that there should be no actual int64 values inside the CoreML model and we can infer any // ONNX_NAMESPACE::TensorProto::INT64 values to be int32. - cast_to_type = ONNX_NAMESPACE::TensorProto::INT32; } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT) { to_dtype = "fp32"; } else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT16) { @@ -69,7 +68,7 @@ Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model if (op_type == "cast") { AddOperationInput(*op, "dtype", model_builder.AddScalarConstant(op->type(), "dtype", std::string(to_dtype))); } - AddOperationOutput(*op, *node.OutputDefs()[0], cast_to_type); + AddOperationOutput(*op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(op)); } diff --git a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc index 3e691cc1745b0..70d3276059f90 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc @@ -30,20 +30,11 @@ int64_t GetAxisAttribute(const Node& node) { } // namespace Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { + const logging::Logger& /*logger*/) const { if (model_builder.CreateMLProgram()) { using CoreML::Specification::MILSpec::Operation; std::unique_ptr op = model_builder.CreateOperation(node, "gather"); - std::optional output_datatype; - - int32_t input_type; - ORT_RETURN_IF_NOT(GetType(*node.InputDefs()[0], input_type, logger), "Failed to get input type"); - - if (input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { - output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32; - } - const auto axis = GetAxisAttribute(node); // coreml docs claims validate_indices is optional but in practice it is required const auto validate_indices = false; @@ -51,7 +42,7 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const AddOperationInput(*op, "indices", node.InputDefs()[1]->Name()); // indices AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis)); // axis attr AddOperationInput(*op, "validate_indices", model_builder.AddScalarConstant(op->type(), "validate_indices", validate_indices)); - AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype); // output + AddOperationOutput(*op, *node.OutputDefs()[0]); // output model_builder.AddOperation(std::move(op)); } else { auto layer = model_builder.CreateNNLayer(node); diff --git a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc index 99d6f01cb8c5b..1ca8ce46857ce 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc @@ -150,6 +150,14 @@ bool PadOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParam LOGS(logger, VERBOSE) << "constant_value must be a constant initializer."; return false; } + + int32_t constant_value_type; + GetType(*input_defs[2], constant_value_type, logger); + + if (constant_value_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + LOGS(logger, VERBOSE) << "Only float constant_value is supported, got type: " << constant_value_type; + return false; + } } { diff --git a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc index d1c87b033d323..2471833e84375 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc @@ -56,10 +56,10 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const std::vector sizes = {size}; AddOperationInput(*slice_op, "begin", model_builder.AddConstant(slice_op->type(), "begin", starts)); AddOperationInput(*slice_op, "size", model_builder.AddConstant(slice_op->type(), "size", sizes)); - AddOperationOutput(*slice_op, *node.OutputDefs()[0], output_datatype); + AddOperationOutput(*slice_op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(slice_op)); } else { - AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype); + AddOperationOutput(*op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(op)); } } else { @@ -127,7 +127,8 @@ bool ShapeOpBuilder::HasSupportedInputsImpl(const Node& node, if (input_params.create_mlprogram) { if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || - input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) { + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64)) { return true; } else { LOGS(logger, VERBOSE) << "[" << node.OpType() diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index 368e47e40f831..bf72fbbf1ace4 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -143,21 +143,6 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } } - // Int32, float and float16 are supported by CoreML slice_by_index. - // We convert any int64 model input to int32 when running the CoreML model for the partition. - // Any other integer data created at runtime is the output from CoreML operations, and should int32 not int64. - // Based on that, we assume that the actual input when running will be int32, so we override the output data - // type to reflect this. - // If we were to leave it as TensorProto_DataType_INT64 the CoreML model would be invalid. - std::optional output_datatype; - - int32_t input_type; - ORT_RETURN_IF_NOT(GetType(*node.InputDefs()[0], input_type, logger), "Failed to get input type"); - - if (input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { - output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32; - } - auto op = model_builder.CreateOperation(node, "slice_by_index"); auto begin = model_builder.AddConstant(op->type(), "begin", AsSpan(compute_metadata.starts_)); @@ -173,7 +158,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const AddOperationInput(*op, "begin_mask", begin_mask); AddOperationInput(*op, "end_mask", end_mask); - AddOperationOutput(*op, *output_defs[0], output_datatype); + AddOperationOutput(*op, *output_defs[0]); model_builder.AddOperation(std::move(op)); diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc index 81bef11906b74..92f0f2bb5fc3d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc @@ -58,8 +58,8 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const } } -void HandleX86ArchUnsqueezeScalarInput(ModelBuilder& model_builder, - const Node& node, const logging::Logger& logger) { +void HandleUnsqueezeScalarInput(ModelBuilder& model_builder, + const Node& node, const logging::Logger& logger) { const auto& input_defs(node.InputDefs()); TensorShapeVector axes; GetAxes(model_builder, node, axes); @@ -86,13 +86,14 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (model_builder.CreateMLProgram()) { using namespace CoreML::Specification::MILSpec; -#if defined(TARGET_CPU_X86_64) && TARGET_CPU_X86_64 - // expand_dims has limited requirements for static shape, however, X86_64 has a bug that it can't handle scalar input + // MLProgram does not support scalar values -- we convert the scalars to 1D tensors. + // So there is a bug when we attempt to unsqueeze what is a + // scalar value in the ONNX graph to a 1D tensor. if (node.OpType() == "Unsqueeze" && input_defs[0]->Shape()->dim_size() < 2) { - HandleX86ArchUnsqueezeScalarInput(model_builder, node, logger); + HandleUnsqueezeScalarInput(model_builder, node, logger); return Status::OK(); } -#endif + std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "expand_dims"; std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); AddOperationInput(*op, "x", input_defs[0]->Name()); diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index c8df88e38e096..1f0fc0056689e 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -382,10 +382,7 @@ MILSpec::Value OnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tenso MILSpec::ValueType& value_type = *value.mutable_type(); MILSpec::TensorType& tensor_type = *value_type.mutable_tensortype(); MILSpec::DataType data_type = OnnxDataTypeToMILSpec(tensor_proto.data_type()); - MILSpec::DataType converted_data_type = data_type == MILSpec::DataType::INT64 - ? MILSpec::DataType::INT32 - : data_type; - tensor_type.set_datatype(converted_data_type); + tensor_type.set_datatype(data_type); tensor_type.set_rank(tensor_proto.dims().size()); for (const auto& dim : tensor_proto.dims()) { @@ -931,11 +928,9 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i // the model inputs need to be wired up as args to the 'main' function. auto tensor_value_type = CreateNamedTensorValueType(node_arg, /*convert_scalar*/ true); - // we need to convert int64 to int32 here as well - if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { - tensor_value_type.mutable_type()->mutable_tensortype()->set_datatype( - OnnxDataTypeToMILSpec(ONNX_NAMESPACE::TensorProto_DataType_INT32)); - } + // Handle conversion from int64 to int32 + tensor_value_type.mutable_type()->mutable_tensortype()->set_datatype( + OnnxDataTypeToMILSpec(data_type)); tensor_value_type.set_name(name); diff --git a/onnxruntime/core/providers/cpu/controlflow/if.cc b/onnxruntime/core/providers/cpu/controlflow/if.cc index 8b17c297e1e5a..6b6bc0e0e4af7 100644 --- a/onnxruntime/core/providers/cpu/controlflow/if.cc +++ b/onnxruntime/core/providers/cpu/controlflow/if.cc @@ -115,9 +115,20 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( // uint4 and int4 support was added. // TODO(adrianlizarraga): Implement int4 and uint4 support. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( If, 21, + 22, + KernelDefBuilder() + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypesIRv9()), + If); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + If, + 23, KernelDefBuilder() .TypeConstraint("B", DataTypeImpl::GetTensorType()) .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypesIRv9()), diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.cc b/onnxruntime/core/providers/cpu/controlflow/loop.cc index b33b1f189594b..a1ec61a6b383f 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.cc +++ b/onnxruntime/core/providers/cpu/controlflow/loop.cc @@ -142,9 +142,21 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( // Opset 21 added int4 and uint4 support. // TODO(adrianlizarraga): Implement int4 and uint4 support. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Loop, 21, + 22, + KernelDefBuilder() + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypesIRv9()), + Loop); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + Loop, + 23, KernelDefBuilder() .TypeConstraint("I", DataTypeImpl::GetTensorType()) .TypeConstraint("B", DataTypeImpl::GetTensorType()) diff --git a/onnxruntime/core/providers/cpu/controlflow/scan_9.cc b/onnxruntime/core/providers/cpu/controlflow/scan_9.cc index 24d233c0594fa..3b31fdd06838d 100644 --- a/onnxruntime/core/providers/cpu/controlflow/scan_9.cc +++ b/onnxruntime/core/providers/cpu/controlflow/scan_9.cc @@ -534,8 +534,19 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(Scan, // Opset 21 starts to support 4-bit int types for the type constraint "V" // TODO(adrianlizarraga): Implement int4 and uint4 support. +ONNX_CPU_OPERATOR_VERSIONED_KERNEL(Scan, + 21, + 22, + KernelDefBuilder() + // 'I' is in the ONNX spec but is not actually used for any inputs or outputs + // .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), + Scan<9>); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. ONNX_CPU_OPERATOR_KERNEL(Scan, - 21, + 23, KernelDefBuilder() // 'I' is in the ONNX spec but is not actually used for any inputs or outputs // .TypeConstraint("I", DataTypeImpl::GetTensorType()) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 94d63362907de..e9cbcb253b304 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -32,10 +32,10 @@ CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info) std::vector CPUExecutionProvider::CreatePreferredAllocators() { const bool create_arena = DoesCpuAllocatorSupportArenaUsage() ? info_.create_arena : false; - AllocatorCreationInfo device_info{[](int) { return std::make_unique(); }, - DEFAULT_CPU_ALLOCATOR_DEVICE_ID, create_arena}; + AllocatorCreationInfo device_info_cpu{[](int) { return std::make_unique(); }, + DEFAULT_CPU_ALLOCATOR_DEVICE_ID, create_arena}; - return std::vector{CreateAllocator(device_info)}; + return std::vector{CreateAllocator(device_info_cpu)}; } // Forward declarations of op kernels @@ -1197,47 +1197,47 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Re class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringSplit); // Opset 21 -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Cast); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, ConstantOfShape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Identity); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Reshape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Scan); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Shape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Size); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Squeeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Unsqueeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, If); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Loop); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Flatten); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint16_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, QLinearMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QLinearMatMul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, ConstantOfShape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Identity); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Scan); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Size); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Loop); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, uint16_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, int16_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Int4x2, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, UInt4x2, DequantizeLinear); #if !defined(DISABLE_FLOAT8_TYPES) -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E5M2, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E5M2FNUZ, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Float8E4M3FN, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Float8E4M3FNUZ, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Float8E5M2, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Float8E5M2FNUZ, DequantizeLinear); #endif -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint16_t, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, uint8_t, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, int8_t, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, uint16_t, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, int16_t, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Int4x2, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, UInt4x2, QuantizeLinear); #if !defined(DISABLE_FLOAT8_TYPES) -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E5M2, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E5M2FNUZ, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Float8E4M3FN, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Float8E4M3FNUZ, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Float8E5M2, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, 22, Float8E5M2FNUZ, QuantizeLinear); #endif +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, QLinearMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QLinearMatMul); // Opset 22 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Acos); @@ -1289,6 +1289,47 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, MLFloat16, AveragePool); #endif +// Opset 23 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, ConstantOfShape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, int32_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, int8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, uint16_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, int16_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Int4x2, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, UInt4x2, DequantizeLinear); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Float8E4M3FNUZ, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Float8E5M2, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Float8E5M2FNUZ, DequantizeLinear); +#endif +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, uint8_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, int8_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, uint16_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, int16_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Int4x2, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, UInt4x2, QuantizeLinear); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Float8E4M3FNUZ, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Float8E5M2, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Float8E5M2FNUZ, QuantizeLinear); +#endif +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Flatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Identity); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Loop); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Reshape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Unsqueeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Scan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Size); + // !!PLEASE READ BELOW!! Following that, add new entries above this comment /* *** IMPORTANT! *** @@ -3094,70 +3135,70 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 21 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 22 BuildKernelCreateInfo, @@ -3206,8 +3247,69 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - }; + // Opset 23 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); if (info.kernel_def != nullptr) { // filter disabled entries where type is void diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index ce9780031a250..badbf1f914fd2 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -81,7 +81,7 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status NonMaxSuppressionBase__PrepareCompute(OpKernelContext* ctx, PrepareContext& pc) override { return NonMaxSuppressionBase::PrepareCompute(ctx, pc); } Status NonMaxSuppressionBase__GetThresholdsFromInputs(const PrepareContext& pc, int64_t& max_output_boxes_per_class, float& iou_threshold, float& score_threshold) override { return NonMaxSuppressionBase::GetThresholdsFromInputs(pc, max_output_boxes_per_class, iou_threshold, score_threshold); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) // From cpu/tensor/size.h (direct) Status Size__Compute(const Size* p, OpKernelContext* context) override { return p->Size::Compute(context); } // From cpu/tensor/scatter_nd.h (direct) diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index eb1569c3e499e..9e49f068c680c 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -38,7 +38,7 @@ struct ProviderHostCPU { virtual Status NonMaxSuppressionBase__PrepareCompute(OpKernelContext* ctx, PrepareContext& pc) = 0; virtual Status NonMaxSuppressionBase__GetThresholdsFromInputs(const PrepareContext& pc, int64_t& max_output_boxes_per_class, float& iou_threshold, float& score_threshold) = 0; -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) // From cpu/tensor/size.h virtual Status Size__Compute(const Size* p, OpKernelContext* context) = 0; @@ -254,7 +254,7 @@ struct ProviderHostCPU { extern ProviderHostCPU& g_host_cpu; -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) namespace GatherElements { inline Status ValidateInputShapes(const TensorShape& input_data_shape, const TensorShape& indices_shape, @@ -336,7 +336,7 @@ inline Status ExecuteTritonOpByFuncName(OpKernelContext* p_ctx, const std::strin } // namespace contrib #endif // ENABLE_TRITON -#endif // USE_CUDA || USE_ROCM +#endif // USE_CUDA || USE_CUDA_PROVIDER_INTERFACE || USE_ROCM #endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc index 24b028b8561f4..3ae5a15275845 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc @@ -19,6 +19,10 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 21, Output, 0, ConstantOfShapeDefaultOutputTypesOpset21); +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 23, Output, 0, + ConstantOfShapeDefaultOutputTypesOpset23); + // pytorch converter uses ConstantOfShape with int64 to create Pad input // https://github.com/pytorch/pytorch/blob/044b519a80459f6787f6723c1c091a18b153d184/torch/onnx/symbolic_opset11.py#L449 ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( @@ -33,6 +37,8 @@ using EnabledOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0); +// ConstantOfShape usually updates the output type list, which is why +// we have a separate type list for it when the opset is updated. using EnabledOutputTypesOpset20 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 20, Output, 0); @@ -41,6 +47,10 @@ using EnabledOutputTypesOpset21 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 21, Output, 0); +using EnabledOutputTypesOpset23 = + ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 23, Output, 0); + class ConstantOfShape final : public ConstantOfShapeBase, public OpKernel { public: explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {} @@ -103,12 +113,24 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( BuildKernelDefConstraintsFromTypeList()), ConstantOfShape); -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( ConstantOfShape, 21, + 22, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()), ConstantOfShape); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + ConstantOfShape, + 23, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", + BuildKernelDefConstraintsFromTypeList()), + ConstantOfShape); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h index 2e4e1730d5e6a..ffd954f13e568 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h @@ -52,6 +52,20 @@ using ConstantOfShapeDefaultOutputTypesOpset21 = uint8_t, uint16_t, uint32_t, uint64_t, bool>; +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. +using ConstantOfShapeDefaultOutputTypesOpset23 = + TypeList< + BFloat16, + MLFloat16, + float, double, +#if !defined(DISABLE_FLOAT8_TYPES) + Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ, +#endif + int8_t, int16_t, int32_t, int64_t, + uint8_t, uint16_t, uint32_t, uint64_t, + bool>; + template class ConstantOfShapeBase { protected: diff --git a/onnxruntime/core/providers/cpu/nn/flatten.cc b/onnxruntime/core/providers/cpu/nn/flatten.cc index f9deef72c40c3..f968b56b83311 100644 --- a/onnxruntime/core/providers/cpu/nn/flatten.cc +++ b/onnxruntime/core/providers/cpu/nn/flatten.cc @@ -43,11 +43,23 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( // Opset 21 added support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. // TODO(adrianlizarraga): Add support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Flatten, 21, + 22, + KernelDefBuilder() + .Alias(0, 0) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + Flatten); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + Flatten, + 23, KernelDefBuilder() .Alias(0, 0) .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), Flatten); + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index 3d3e831a12d13..adb2aee171f39 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -120,7 +120,7 @@ static void PrepareForQDQ(const TensorShape& input_shape, #define REGISTER_DEQUANTIZELINEAR(T) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ DequantizeLinear, \ - 21, \ + 23, \ T, \ KernelDefBuilder() \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ @@ -128,11 +128,11 @@ static void PrepareForQDQ(const TensorShape& input_shape, DataTypeImpl::GetTensorType()}), \ DequantizeLinear); -#define REGISTER_DEQUANTIZELINEAR_VERSIONED(T) \ +#define REGISTER_DEQUANTIZELINEAR_VERSIONED(T, start_version, end_version) \ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ DequantizeLinear, \ - 19, \ - 20, \ + start_version, \ + end_version, \ T, \ KernelDefBuilder() \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ @@ -159,8 +159,8 @@ static void PrepareForQDQ(const TensorShape& input_shape, .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ DequantizeLinear); -// Opset 21 added 16-bit and 4-bit int to DQ. -// TODO(adrianlizarraga): Also support 4-bit int types and 'block' quantization. +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. REGISTER_DEQUANTIZELINEAR(int8_t) REGISTER_DEQUANTIZELINEAR(uint8_t) REGISTER_DEQUANTIZELINEAR(int16_t) @@ -175,15 +175,31 @@ REGISTER_DEQUANTIZELINEAR(Float8E5M2) REGISTER_DEQUANTIZELINEAR(Float8E5M2FNUZ) #endif +// Opset 21 added 16-bit and 4-bit int to DQ. +// TODO(adrianlizarraga): Also support 4-bit int types and 'block' quantization. +REGISTER_DEQUANTIZELINEAR_VERSIONED(int8_t, 21, 22) +REGISTER_DEQUANTIZELINEAR_VERSIONED(uint8_t, 21, 22) +REGISTER_DEQUANTIZELINEAR_VERSIONED(int16_t, 21, 22) +REGISTER_DEQUANTIZELINEAR_VERSIONED(uint16_t, 21, 22) +REGISTER_DEQUANTIZELINEAR_VERSIONED(int32_t, 21, 22) +REGISTER_DEQUANTIZELINEAR_VERSIONED(Int4x2, 21, 22) +REGISTER_DEQUANTIZELINEAR_VERSIONED(UInt4x2, 21, 22) +#if !defined(DISABLE_FLOAT8_TYPES) +REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E4M3FN, 21, 22) +REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E4M3FNUZ, 21, 22) +REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E5M2, 21, 22) +REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E5M2FNUZ, 21, 22) +#endif + // Opset 19 added 8-bit float inputs and 16-bit float outputs to DQ. -REGISTER_DEQUANTIZELINEAR_VERSIONED(int8_t) -REGISTER_DEQUANTIZELINEAR_VERSIONED(uint8_t) -REGISTER_DEQUANTIZELINEAR_VERSIONED(int32_t) +REGISTER_DEQUANTIZELINEAR_VERSIONED(int8_t, 19, 20) +REGISTER_DEQUANTIZELINEAR_VERSIONED(uint8_t, 19, 20) +REGISTER_DEQUANTIZELINEAR_VERSIONED(int32_t, 19, 20) #if !defined(DISABLE_FLOAT8_TYPES) -REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E4M3FN) -REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E4M3FNUZ) -REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E5M2) -REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E5M2FNUZ) +REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E4M3FN, 19, 20) +REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E4M3FNUZ, 19, 20) +REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E5M2, 19, 20) +REGISTER_DEQUANTIZELINEAR_VERSIONED(Float8E5M2FNUZ, 19, 20) #endif // Before opset 19, DQ only supported int8, uint8 and int32. @@ -540,7 +556,7 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { #define REGISTER_QUANTIZELINEAR(T) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ QuantizeLinear, \ - 21, \ + 23, \ T, \ KernelDefBuilder() \ .TypeConstraint("T1", {DataTypeImpl::GetTensorType(), \ @@ -548,11 +564,11 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ QuantizeLinear); -#define REGISTER_QUANTIZELINEAR_VERSIONED(T) \ +#define REGISTER_QUANTIZELINEAR_VERSIONED(T, start_version, end_version) \ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ QuantizeLinear, \ - 19, \ - 20, \ + start_version, \ + end_version, \ T, \ KernelDefBuilder() \ .TypeConstraint("T1", {DataTypeImpl::GetTensorType(), \ @@ -581,15 +597,14 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ QuantizeLinear); -// Opset 21 added 16-bit and 4-bit int support to Q ops. -// TODO(adrianlizarraga): Support int4 and block quantization. +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. REGISTER_QUANTIZELINEAR(int8_t) REGISTER_QUANTIZELINEAR(uint8_t) REGISTER_QUANTIZELINEAR(int16_t) REGISTER_QUANTIZELINEAR(uint16_t) REGISTER_QUANTIZELINEAR(Int4x2) REGISTER_QUANTIZELINEAR(UInt4x2) - #if !defined(DISABLE_FLOAT8_TYPES) REGISTER_QUANTIZELINEAR(Float8E4M3FN) REGISTER_QUANTIZELINEAR(Float8E4M3FNUZ) @@ -597,15 +612,29 @@ REGISTER_QUANTIZELINEAR(Float8E5M2) REGISTER_QUANTIZELINEAR(Float8E5M2FNUZ) #endif -// Opset 19 added 8-bit floats to Q ops. -REGISTER_QUANTIZELINEAR_VERSIONED(int8_t) -REGISTER_QUANTIZELINEAR_VERSIONED(uint8_t) +// Opset 21 added 16-bit and 4-bit int support to Q ops. +// TODO(adrianlizarraga): Support int4 and block quantization. +REGISTER_QUANTIZELINEAR_VERSIONED(int8_t, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED(uint8_t, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED(int16_t, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED(uint16_t, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED(Int4x2, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED(UInt4x2, 21, 22) +#if !defined(DISABLE_FLOAT8_TYPES) +REGISTER_QUANTIZELINEAR_VERSIONED(Float8E4M3FN, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED(Float8E4M3FNUZ, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2FNUZ, 21, 22) +#endif +// Opset 19 added 8-bit floats to Q ops. +REGISTER_QUANTIZELINEAR_VERSIONED(int8_t, 19, 20) +REGISTER_QUANTIZELINEAR_VERSIONED(uint8_t, 19, 20) #if !defined(DISABLE_FLOAT8_TYPES) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E4M3FN) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E4M3FNUZ) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2FNUZ) +REGISTER_QUANTIZELINEAR_VERSIONED(Float8E4M3FN, 19, 20) +REGISTER_QUANTIZELINEAR_VERSIONED(Float8E4M3FNUZ, 19, 20) +REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2, 19, 20) +REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2FNUZ, 19, 20) #endif // Before opset 19, Q only supported int8 and uint8. diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 639a49cb43a4f..e14a8d6b87fb0 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -464,9 +464,21 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Cast); // TODO(adrianlizarraga): Implement support for int4 and uint4. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Cast, 21, + 22, + KernelDefBuilder() + .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) + .MayInplace(0, 0), // allocation planner will check input and output sizes match before inplacing + Cast); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Implement support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + Cast, + 23, KernelDefBuilder() .TypeConstraint("T1", BuildKernelDefConstraintsFromTypeList()) .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()) diff --git a/onnxruntime/core/providers/cpu/tensor/identity_op.cc b/onnxruntime/core/providers/cpu/tensor/identity_op.cc index 5ccd99f94a581..ff032108e109c 100644 --- a/onnxruntime/core/providers/cpu/tensor/identity_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/identity_op.cc @@ -58,9 +58,18 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( IdentityOp); // TODO(liqunfu): Opset 21 supported int4 and uint4 types. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Identity, 21, + 22, + KernelDefBuilder().TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypesIRv9()).Alias(0, 0), + IdentityOp); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + Identity, + 23, KernelDefBuilder().TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorAndOptionalTypesIRv9()).Alias(0, 0), IdentityOp); diff --git a/onnxruntime/core/providers/cpu/tensor/pad.cc b/onnxruntime/core/providers/cpu/tensor/pad.cc index dc590ab8422a7..e9f4fe9782fd1 100644 --- a/onnxruntime/core/providers/cpu/tensor/pad.cc +++ b/onnxruntime/core/providers/cpu/tensor/pad.cc @@ -107,6 +107,20 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES( uint8_t, bool); +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES( + kCpuExecutionProvider, kOnnxDomain, Pad, 23, Input, 0, + float, + double, + int32_t, + int64_t, + uint32_t, + uint64_t, + int8_t, + uint8_t, + bool); + ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES( kCpuExecutionProvider, kOnnxDomain, Pad, 11, Input, 0, int32_t, int64_t); ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES( @@ -117,6 +131,8 @@ ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES( kCpuExecutionProvider, kOnnxDomain, Pad, 19, Input, 0, int32_t, int64_t); ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES( kCpuExecutionProvider, kOnnxDomain, Pad, 21, Input, 0, int32_t, int64_t); +ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES( + kCpuExecutionProvider, kOnnxDomain, Pad, 23, Input, 0, int32_t, int64_t); } // namespace op_kernel_type_control using EnabledPad2Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( @@ -131,6 +147,8 @@ using EnabledPad19Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( kCpuExecutionProvider, kOnnxDomain, Pad, 19, Input, 0); using EnabledPad21Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( kCpuExecutionProvider, kOnnxDomain, Pad, 21, Input, 0); +using EnabledPad23Types = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, Pad, 23, Input, 0); using AllEnabledPadTypes = utils::TypeSetUnion< @@ -185,15 +203,24 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( BuildKernelDefConstraintsFromTypeList()), Pad); -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Pad, - 21, + 21, 22, KernelDefBuilder() .TypeConstraint( "T", BuildKernelDefConstraintsFromTypeList()), Pad); +ONNX_CPU_OPERATOR_KERNEL( + Pad, + 23, + KernelDefBuilder() + .TypeConstraint( + "T", + BuildKernelDefConstraintsFromTypeList()), + Pad); + using PadsVector = PadBase::PadsVector; Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) { diff --git a/onnxruntime/core/providers/cpu/tensor/reshape.cc b/onnxruntime/core/providers/cpu/tensor/reshape.cc index 3038213bfe577..805b0d2c6b70e 100644 --- a/onnxruntime/core/providers/cpu/tensor/reshape.cc +++ b/onnxruntime/core/providers/cpu/tensor/reshape.cc @@ -53,9 +53,21 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( // Opset 21 added support for int4 and uint4. // TODO(adrianlizarraga): Implement int4 and uint4 support. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Reshape, 21, + 22, + KernelDefBuilder() + .Alias(0, 0) + .TypeConstraint("T", DataTypeImpl::AllTensorTypesIRv9()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()), + Reshape); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + Reshape, + 23, KernelDefBuilder() .Alias(0, 0) .TypeConstraint("T", DataTypeImpl::AllTensorTypesIRv9()) diff --git a/onnxruntime/core/providers/cpu/tensor/shape_op.cc b/onnxruntime/core/providers/cpu/tensor/shape_op.cc index 91d9e4581e788..0cedb21ca76bb 100644 --- a/onnxruntime/core/providers/cpu/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/shape_op.cc @@ -32,9 +32,18 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( // Opset 21 added support for int4 and uint4. // TODO(adrianlizarraga): Implement int4 and uint4 support. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Shape, 21, + 22, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypesIRv9()).TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Implement float4e2m1 support. +ONNX_CPU_OPERATOR_KERNEL( + Shape, + 23, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypesIRv9()).TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); diff --git a/onnxruntime/core/providers/cpu/tensor/size.cc b/onnxruntime/core/providers/cpu/tensor/size.cc index a994845d58332..4ee889bd28fde 100644 --- a/onnxruntime/core/providers/cpu/tensor/size.cc +++ b/onnxruntime/core/providers/cpu/tensor/size.cc @@ -85,9 +85,31 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( // Opset 21 added the int4 and uint4 types. // TODO(adrianlizarraga): Implement support for int4 and uint4. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Size, 21, + 22, + KernelDefBuilder().TypeConstraint("T", + std::vector({DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()})) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Size); + +// Opset 23 added the float4e2m1 type. +// TODO(titaiwang): Implement support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + Size, + 23, KernelDefBuilder().TypeConstraint("T", std::vector({DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/core/providers/cpu/tensor/squeeze.cc b/onnxruntime/core/providers/cpu/tensor/squeeze.cc index 5217786ca14c0..8e24434a6f9fb 100644 --- a/onnxruntime/core/providers/cpu/tensor/squeeze.cc +++ b/onnxruntime/core/providers/cpu/tensor/squeeze.cc @@ -36,11 +36,23 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( // Opset 21 added support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. // TODO(adrianlizarraga): Implement support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Squeeze, 21, + 22, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .Alias(0, 0), + Squeeze); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + Squeeze, + 23, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) .Alias(0, 0), Squeeze); + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index 5b904e85848d0..d4fea3c5a75c7 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -469,9 +469,18 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( // Opset 21 added support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. // TODO(adrianlizarraga): Implement support for float8e4m3fnuz, float8e5m2, and float8e5m2fnuz. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Transpose, 21, + 22, + KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), + Transpose); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Implement support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + Transpose, + 23, KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), Transpose); diff --git a/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc index 3e521bcc4cbfe..19eb57c150749 100644 --- a/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc +++ b/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc @@ -39,9 +39,20 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( // Opset 21 added support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. // TODO(adrianlizarraga): Implement support for float8e4m3fnuz, float8e5m2, float8e5m2fnuz, int4 and uint4. -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Unsqueeze, 21, + 22, + KernelDefBuilder() + .Alias(0, 0) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + Unsqueeze); + +// Opset 23 added support for float4e2m1. +// TODO(titaiwang): Add support for float4e2m1. +ONNX_CPU_OPERATOR_KERNEL( + Unsqueeze, + 23, KernelDefBuilder() .Alias(0, 0) .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index ed8d6ea71aea4..c4520fe38cd2a 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -314,7 +314,7 @@ struct CudaEpFactory : OrtEpFactory { CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { GetName = GetNameImpl; GetVendor = GetVendorImpl; - GetDeviceInfoIfSupported = GetDeviceInfoIfSupportedImpl; + GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; } @@ -329,18 +329,26 @@ struct CudaEpFactory : OrtEpFactory { return factory->vendor.c_str(); } - static bool GetDeviceInfoIfSupportedImpl(const OrtEpFactory* this_ptr, - const OrtHardwareDevice* device, - _Out_opt_ OrtKeyValuePairs** /*ep_metadata*/, - _Out_opt_ OrtKeyValuePairs** /*ep_options*/) { - const auto* factory = static_cast(this_ptr); - - if (factory->ort_api.HardwareDevice_Type(device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && - factory->ort_api.HardwareDevice_VendorId(device) == 0x10de) { - return true; + static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && + factory->ort_api.HardwareDevice_VendorId(&device) == 0x10de) { + ORT_API_RETURN_IF_ERROR( + factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } } - return false; + return nullptr; } static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/, @@ -385,7 +393,7 @@ OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase } OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { - delete factory; + delete static_cast(factory); return nullptr; } } diff --git a/onnxruntime/core/providers/cuda/onnxruntime_providers_cuda.rc b/onnxruntime/core/providers/cuda/onnxruntime_providers_cuda.rc new file mode 100644 index 0000000000000..189238071c389 --- /dev/null +++ b/onnxruntime/core/providers/cuda/onnxruntime_providers_cuda.rc @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file REQUIRES the following external definitions: +// FILE_NAME, VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE, and VER_STRING + +#include + +#if defined(DEBUG) || defined(_DEBUG) +#define VER_DEBUG VS_FF_DEBUG +#else +#define VER_DEBUG 0 +#endif + +// ----------------------------------------------------------------------------- + +VS_VERSION_INFO VERSIONINFO +FILEVERSION VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE +PRODUCTVERSION VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE +FILEFLAGSMASK VS_FFI_FILEFLAGSMASK +FILEFLAGS VER_DEBUG +FILEOS VOS__WINDOWS32 +FILETYPE VFT_DLL +FILESUBTYPE VFT2_UNKNOWN + +BEGIN + BLOCK "StringFileInfo" + BEGIN + BLOCK "040904E4" + BEGIN + VALUE "CompanyName", "Microsoft Corporation" + VALUE "FileDescription", "ONNX Runtime CUDA Provider" + VALUE "FileVersion", VER_STRING + VALUE "InternalName", "ONNX Runtime CUDA Provider" + VALUE "LegalCopyright", "\251 Microsoft Corporation. All rights reserved." + VALUE "OriginalFilename", FILE_NAME + VALUE "ProductName", "Microsoft\256 Windows\256 Operating System" + VALUE "ProductVersion", VER_STRING + END + END + + BLOCK "VarFileInfo" + BEGIN + VALUE "Translation", 0x409, 1252 + END +END diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h index ca59c5f11ed1b..5d2b1a7e2b6fd 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h @@ -72,7 +72,7 @@ class BaseOpBuilder : public IOpBuilder { const OpSupportCheckParams& params) const; virtual int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const { return 1; } - virtual int GetMaxSupportedOpSet(const NodeUnit& /* node_unit */) const { return 21; } + virtual int GetMaxSupportedOpSet(const NodeUnit& /* node_unit */) const { return 23; } // Check if this node_unit's type is supported // SingleNode type NodeUnit is supported diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 42f8f9fe8a62c..25c130a849793 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -281,7 +281,7 @@ bool ApplyProfileShapesFromProviderOptions(std::vector>>& profile_opt_shapes, ShapeRangesMap& input_explicit_shape_ranges) { if (trt_profiles.size() == 0) { - LOGS_DEFAULT(WARNING) << "[Nv EP] Number of optimization profiles should be greater than 0, but it's 0."; + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Number of optimization profiles should be greater than 0, but it's 0."; return false; } @@ -295,8 +295,8 @@ bool ApplyProfileShapesFromProviderOptions(std::vectorgetDimensions(); @@ -309,7 +309,7 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(profile_min_shapes[input_name][i].size()); std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] shape size of this shape tensor is " << shape_size; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] shape size of this shape tensor is " << shape_size; for (int j = 0; j < shape_size; j++) { auto min_value = profile_min_shapes[input_name][i][j]; @@ -318,9 +318,9 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(min_value); shapes_max[j] = static_cast(max_value); shapes_opt[j] = static_cast(opt_value); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] shapes_min.d[" << j << "] is " << shapes_min[j]; - LOGS_DEFAULT(VERBOSE) << "[Nv EP] shapes_max.d[" << j << "] is " << shapes_max[j]; - LOGS_DEFAULT(VERBOSE) << "[Nv EP] shapes_opt.d[" << j << "] is " << shapes_opt[j]; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] shapes_min.d[" << j << "] is " << shapes_min[j]; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] shapes_max.d[" << j << "] is " << shapes_max[j]; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] shapes_opt.d[" << j << "] is " << shapes_opt[j]; if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) { std::vector> profile_vector(trt_profiles.size()); @@ -342,7 +342,7 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(min_value); dims_max.d[j] = static_cast(max_value); dims_opt.d[j] = static_cast(opt_value); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] dims_min.d[" << j << "] is " << dims_min.d[j]; - LOGS_DEFAULT(VERBOSE) << "[Nv EP] dims_max.d[" << j << "] is " << dims_max.d[j]; - LOGS_DEFAULT(VERBOSE) << "[Nv EP] dims_opt.d[" << j << "] is " << dims_opt.d[j]; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] dims_min.d[" << j << "] is " << dims_min.d[j]; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] dims_max.d[" << j << "] is " << dims_max.d[j]; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] dims_opt.d[" << j << "] is " << dims_opt.d[j]; if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) { std::vector> profile_vector(trt_profiles.size()); @@ -933,7 +933,7 @@ NvExecutionProvider::PerThreadContext::~PerThreadContext() { bool NvExecutionProvider::PerThreadContext::CompareProfileShapes(std::string fused_node, ShapeRangesMap& shape_ranges) { if (shape_ranges.size() > 0) { if (input_shape_ranges_[fused_node] != shape_ranges) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] The shape ranges maintained by the PerThreadContext is different from the shape ranges maintained by TRT EP. \ + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] The shape ranges maintained by the PerThreadContext is different from the shape ranges maintained by TRT EP. \ This means the engine is updated and will need to update the execution context as well."; return true; } @@ -1068,53 +1068,95 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) } }; - // Get environment variables - if (info.has_trt_options) { - max_partition_iterations_ = info.max_partition_iterations; - min_subgraph_size_ = info.min_subgraph_size; - max_workspace_size_ = info.max_workspace_size; - dump_subgraphs_ = info.dump_subgraphs; - weight_stripped_engine_enable_ = info.weight_stripped_engine_enable; - onnx_model_folder_path_ = info.onnx_model_folder_path; - onnx_model_bytestream_ = info.onnx_bytestream; - onnx_model_bytestream_size_ = info.onnx_bytestream_size; - if ((onnx_model_bytestream_ != nullptr && onnx_model_bytestream_size_ == 0) || - (onnx_model_bytestream_ == nullptr && onnx_model_bytestream_size_ != 0)) { - ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "When providing either 'trt_onnx_bytestream_size' or " - "'trt_onnx_bytestream' both have to be provided")); - } - detailed_build_log_ = info.detailed_build_log; - dump_ep_context_model_ = info.dump_ep_context_model; - ep_context_file_path_ = info.ep_context_file_path; - ep_context_embed_mode_ = info.ep_context_embed_mode; - enable_engine_cache_for_ep_context_model(); - cache_prefix_ = info.engine_cache_prefix; - // use a more global cache if given - engine_decryption_enable_ = info.engine_decryption_enable; - if (engine_decryption_enable_) { - engine_decryption_lib_path_ = info.engine_decryption_lib_path; - } - force_sequential_engine_build_ = info.force_sequential_engine_build; - context_memory_sharing_enable_ = info.context_memory_sharing_enable; - sparsity_enable_ = info.sparsity_enable; - auxiliary_streams_ = info.auxiliary_streams; - profile_min_shapes = info.profile_min_shapes; - profile_max_shapes = info.profile_max_shapes; - profile_opt_shapes = info.profile_opt_shapes; - cuda_graph_enable_ = info.cuda_graph_enable; - op_types_to_exclude_ = info.op_types_to_exclude; - } else { - LOGS_DEFAULT(INFO) << "[Nv EP] Options were not specified"; + max_partition_iterations_ = info.max_partition_iterations; + min_subgraph_size_ = info.min_subgraph_size; + max_workspace_size_ = info.max_workspace_size; + dump_subgraphs_ = info.dump_subgraphs; + weight_stripped_engine_enable_ = info.weight_stripped_engine_enable; + onnx_model_folder_path_ = info.onnx_model_folder_path; + onnx_model_bytestream_ = info.onnx_bytestream; + onnx_model_bytestream_size_ = info.onnx_bytestream_size; + if ((onnx_model_bytestream_ != nullptr && onnx_model_bytestream_size_ == 0) || + (onnx_model_bytestream_ == nullptr && onnx_model_bytestream_size_ != 0)) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "When providing either 'trt_onnx_bytestream_size' or " + "'trt_onnx_bytestream' both have to be provided")); + } + detailed_build_log_ = info.detailed_build_log; + dump_ep_context_model_ = info.dump_ep_context_model; + ep_context_file_path_ = info.ep_context_file_path; + ep_context_embed_mode_ = info.ep_context_embed_mode; + enable_engine_cache_for_ep_context_model(); + cache_prefix_ = info.engine_cache_prefix; + // use a more global cache if given + engine_decryption_enable_ = info.engine_decryption_enable; + if (engine_decryption_enable_) { + engine_decryption_lib_path_ = info.engine_decryption_lib_path; + } + force_sequential_engine_build_ = info.force_sequential_engine_build; + context_memory_sharing_enable_ = info.context_memory_sharing_enable; + sparsity_enable_ = info.sparsity_enable; + auxiliary_streams_ = info.auxiliary_streams; + profile_min_shapes = info.profile_min_shapes; + profile_max_shapes = info.profile_max_shapes; + profile_opt_shapes = info.profile_opt_shapes; + + /* + * Parse explicit min/max/opt profile shapes from provider options. + * + * The format of min/max/opt profile shapes is defined as below: + * "input1:dim1xdim2...,input2:dim1xdim2...,...,input1:dim3xdim4...,input2:dim3xdim4...,..." + * + * (Note: if multiple shapes with same input name are specified, TRT EP will consider them as multiple profiles. + * Please refer to ParserProfileShapes() for more details) + * + */ + bool status = true; + if (status) { + status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_); + if (!status) { + profile_min_shapes_.clear(); + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + } } + if (status) { + status = ParseProfileShapes(profile_max_shapes, profile_max_shapes_); + if (!status) { + profile_max_shapes_.clear(); + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The format of provider option 'trt_profile_max_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + } + } + + if (status) { + status = ParseProfileShapes(profile_opt_shapes, profile_opt_shapes_); + if (!status) { + profile_opt_shapes_.clear(); + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The format of provider option 'trt_profile_opt_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + } + } + + if (status) { + status = ValidateProfileShapes(profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); + if (!status) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Profile shapes validation failed. Make sure the provider options 'trt_profile_min_shapes', 'trt_profile_max_shapes' and 'trt_profile_opt_shapes' have same input name and number of profile."; + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] TRT EP will implicitly create optimization profiles based on input tensor for you."; + profile_min_shapes_.clear(); + profile_max_shapes_.clear(); + profile_opt_shapes_.clear(); + } + } + + cuda_graph_enable_ = info.cuda_graph_enable; + op_types_to_exclude_ = info.op_types_to_exclude; + // Validate setting if (max_partition_iterations_ <= 0) { - // LOGS_DEFAULT(WARNING) << "[Nv EP] TensorRT option nv_max_partition_iterations must be a positive integer value. Set it to 1000"; + // LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] TensorRT option nv_max_partition_iterations must be a positive integer value. Set it to 1000"; max_partition_iterations_ = 1000; } if (min_subgraph_size_ <= 0) { - // LOGS_DEFAULT(WARNING) << "[Nv EP] TensorRT option nv_min_subgraph_size must be a positive integer value. Set it to 1"; + // LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] TensorRT option nv_min_subgraph_size must be a positive integer value. Set it to 1"; min_subgraph_size_ = 1; } @@ -1181,10 +1223,10 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) trt_version_ = getInferLibVersion(); CUDA_CALL_THROW(cudaRuntimeGetVersion(&cuda_version_)); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] TensorRT version is " << trt_version_; - LOGS_DEFAULT(VERBOSE) << "[Nv EP] CUDA version is " << cuda_version_; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] TensorRT version is " << trt_version_; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] CUDA version is " << cuda_version_; - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Nv provider options: " + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Nv provider options: " << "device_id: " << device_id_ << ", nv_max_partition_iterations: " << max_partition_iterations_ << ", nv_min_subgraph_size: " << min_subgraph_size_ @@ -1311,15 +1353,9 @@ nvinfer1::IBuilder* NvExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) } void NvExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { - std::string extra_plugin_lib_paths{""}; - if (info_.has_trt_options) { - if (!info_.extra_plugin_lib_paths.empty()) { - extra_plugin_lib_paths = info_.extra_plugin_lib_paths; - } - } - auto status = CreateTensorRTCustomOpDomainList(custom_op_domain_list, extra_plugin_lib_paths); + auto status = CreateTensorRTCustomOpDomainList(custom_op_domain_list, info_.extra_plugin_lib_paths); if (status != Status::OK()) { - LOGS_DEFAULT(WARNING) << "[Nv EP] Failed to get TRT plugins from TRT plugin registration."; + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to get TRT plugins from TRT plugin registration."; } } @@ -1498,7 +1534,7 @@ std::unique_ptr NvExecutionProvider::GetSubGraph(SubGraph_t gra auto meta_def = IndexedSubGraph_MetaDef::Create(); const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph"; meta_def->name() = "TRTKernel_" + graph_type + "_" + graph.Name() + "_" + subgraph_id; - LOGS_DEFAULT(INFO) << "[Nv EP] TensorRT subgraph MetaDef name " + meta_def->name(); + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] TensorRT subgraph MetaDef name " + meta_def->name(); // Assign inputs and outputs to subgraph's meta_def for (const auto& input : inputs) { @@ -1619,7 +1655,7 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t // Only if the newly built graph has control flow op as well as it has parent node, // it needs to handle outer scope values before calling graph.Resolve(). if (has_control_flow_op && graph.ParentNode()) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Handle outer scope values for the subgraph " << graph_build.Name(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Handle outer scope values for the subgraph " << graph_build.Name(); BuildSubGraphContext(graph_build); SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph()); SetAllGraphInputs(graph_build); @@ -2005,9 +2041,9 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, } SubGraphCollection_t consolidated_supported_nodes_vector = {{nodes_vector, true}}; if (DetectTensorRTGraphCycles(consolidated_supported_nodes_vector, graph, model_hash, false)) { - LOGS_DEFAULT(INFO) << "[Nv EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation"; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] TensorRT nodes are not consolidated because graph will have cycles after consolidation"; } else { - LOGS_DEFAULT(INFO) << "[Nv EP] TensorRT nodes are consolidated into one subgraph"; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] TensorRT nodes are consolidated into one subgraph"; supported_nodes_vector = consolidated_supported_nodes_vector; } } @@ -2072,7 +2108,7 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, } } } - LOGS_DEFAULT(INFO) << "[Nv EP] Whole graph will run on Nv execution provider"; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Whole graph will run on Nv execution provider"; // The context map is only used during EP compile time, release it to save memory space. subgraph_context_map_.clear(); @@ -2092,11 +2128,11 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, const size_t number_of_subgraphs = supported_nodes_vector.size(); if (number_of_trt_nodes == 0) { - LOGS_DEFAULT(WARNING) << "[Nv EP] No graph will run on Nv execution provider"; + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] No graph will run on Nv execution provider"; } else if (number_of_trt_nodes == number_of_ort_nodes) { - LOGS_DEFAULT(INFO) << "[Nv EP] Whole graph will run on Nv execution provider"; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Whole graph will run on Nv execution provider"; } else { - LOGS_DEFAULT(INFO) << "[Nv EP] Graph is partitioned and number of subgraphs running on Nv executio provider is " << number_of_subgraphs; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Graph is partitioned and number of subgraphs running on Nv executio provider is " << number_of_subgraphs; } // The context map is only used during EP compile time, release it to save memory space. @@ -2154,20 +2190,20 @@ common::Status NvExecutionProvider::RefitEngine(std::string onnx_model_filename, auto parser_refitter = std::unique_ptr( nvonnxparser::createParserRefitter(*refitter, trt_logger)); if (refit_from_file) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Refitting from file on disk: " << onnx_model_path.string(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from file on disk: " << onnx_model_path.string(); if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); } } else { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Refitting from byte array"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from byte array"; if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestraem"); } } if (refitter->refitCudaEngine()) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Successfully refitted the weight-stripped engine."; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Successfully refitted the weight-stripped engine."; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); @@ -2179,7 +2215,7 @@ common::Status NvExecutionProvider::RefitEngine(std::string onnx_model_filename, nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize(); std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out); engine_file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Serialize the refitted engine to " << refitted_engine_cache; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Serialize the refitted engine to " << refitted_engine_cache; } return Status::OK(); } @@ -2342,7 +2378,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr has_dynamic_shape |= tensor_is_dynamic(input); } if (has_dynamic_shape) { - LOGS_DEFAULT(WARNING) << "[Nv EP] No explicit optimization profile was specified. " + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] No explicit optimization profile was specified. " "We will assume a single profile with fully dynamic range. " "This feature is experimental and may change in the future." "If you plan to use this model as fixed shape we recommend using a free dimension override: " @@ -2365,7 +2401,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr if (has_explicit_profile && tensor_has_profile) { apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); } else { - LOGS_DEFAULT(INFO) << "[Nv EP] Creating implicit profile for tensor " << input_name; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Creating implicit profile for tensor " << input_name; profile_min_shapes_[input_name] = std::vector>{{}}; profile_min_shapes_[input_name][0].resize(dims.nbDims); profile_opt_shapes_[input_name] = std::vector>{{}}; @@ -2422,20 +2458,20 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // enable sparse weights if (sparsity_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Sparse weights are allowed"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Sparse weights are allowed"; } // limit auxiliary streams if (auxiliary_streams_ >= 0) { trt_config->setMaxAuxStreams(auxiliary_streams_); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Auxiliary streams are se to " << auxiliary_streams_; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Auxiliary streams are se to " << auxiliary_streams_; } if (weight_stripped_engine_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] STRIP_PLAN is enabled"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] STRIP_PLAN is enabled"; trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] REFIT_IDENTICAL is enabled"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] REFIT_IDENTICAL is enabled"; } // Build TRT engine (if needed) and load TRT engine if: @@ -2518,7 +2554,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } if (weight_stripped_engine_refit_) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Refit engine from main ONNX file after engine build"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refit engine from main ONNX file after engine build"; char* onnx = string_buf.data(); size_t onnx_size = string_buf.size(); auto status = RefitEngine(model_path_, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc index 5559e2e791d40..0806ae3638036 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc @@ -58,9 +58,9 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& while (std::getline(extra_plugin_libs, lib, ';')) { auto status = LoadDynamicLibrary(ToPathString(lib)); if (status == Status::OK()) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Successfully load " << lib; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Successfully load " << lib; } else { - LOGS_DEFAULT(WARNING) << "[Nv EP]" << status.ToString(); + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP]" << status.ToString(); } } is_loaded = true; @@ -68,7 +68,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& try { // Get all registered TRT plugins from registry - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Getting all registered TRT plugins from TRT plugin registry ..."; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Getting all registered TRT plugins from TRT plugin registry ..."; TensorrtLogger trt_logger = GetTensorrtLogger(false); void* library_handle = nullptr; const auto& env = onnxruntime::GetDefaultEnv(); @@ -79,14 +79,14 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& bool (*dyn_initLibNvInferPlugins)(void* logger, char const* libNamespace); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "initLibNvInferPlugins", (void**)&dyn_initLibNvInferPlugins)); dyn_initLibNvInferPlugins(&trt_logger, ""); - LOGS_DEFAULT(INFO) << "[Nv EP] Default plugins successfully loaded."; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Default plugins successfully loaded."; #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) // Ignore warning C4996: 'nvinfer1::*' was declared deprecated #endif } catch (const std::exception&) { - LOGS_DEFAULT(INFO) << "[Nv EP] Default plugin library is not on the path and is therefore ignored"; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Default plugin library is not on the path and is therefore ignored"; } try { int num_plugin_creator = 0; @@ -96,7 +96,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& for (int i = 0; i < num_plugin_creator; i++) { auto plugin_creator = plugin_creators[i]; std::string plugin_name(plugin_creator->getPluginName()); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion(); // plugin has different versions and we only register once if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) { @@ -116,7 +116,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& custom_op_domain->domain_ = "trt.plugins"; domain_list.push_back(custom_op_domain.get()); } catch (const std::exception&) { - LOGS_DEFAULT(WARNING) << "[Nv EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins"; + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins"; } return Status::OK(); } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc index 5373b6fd08afc..cd50f1e6b2d48 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_helper.cc @@ -169,31 +169,31 @@ void NvExecutionProvider::SetGraphOuterScopeValuesAndInputs(Graph& graph_build, } std::string unique_graph_name = GetUniqueGraphName(*top_level_graph); if (subgraph_context_map_.find(unique_graph_name) == subgraph_context_map_.end()) { - LOGS_DEFAULT(ERROR) << "[Nv EP] Can't find top-level graph context. \ + LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Can't find top-level graph context. \ Please check BuildSubGraphContext() has built the graph context correctly."; return; } SubGraphContext* context = subgraph_context_map_.at(unique_graph_name).get(); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Subgraph name is " << graph_build.Name(); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Its parent node is " << graph.ParentNode()->Name(); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Its parent node's implicit inputs:"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Subgraph name is " << graph_build.Name(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Its parent node is " << graph.ParentNode()->Name(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Its parent node's implicit inputs:"; // Iterate all the implicit inputs to set outer scope value for the newly built subgraph for (const auto& input : graph.ParentNode()->ImplicitInputDefs()) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] \t" << input->Name(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] \t" << input->Name(); // The node arg in parent node's implicit inputs could be used for parent node's other subgraph, for example // "If" op has two subgraphs. So we need to make sure that the node arg is used in current subgraph only. // (GetNodeArg searches for specific node arg in all node args in the graph) if (graph_build.GetNodeArg(input->Name())) { graph_build.AddOuterScopeNodeArg(input->Name()); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] \t" << input->Name() << " is used in this subgraph"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] \t" << input->Name() << " is used in this subgraph"; if (context && (context->manually_added_graph_inputs.find(input->Name()) != context->manually_added_graph_inputs.end())) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] \t" << input->Name() << " is already been added as an explicit input to graph"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] \t" << input->Name() << " is already been added as an explicit input to graph"; continue; } @@ -213,7 +213,7 @@ void NvExecutionProvider::SetGraphOuterScopeValuesAndInputs(Graph& graph_build, type_proto->copy_from(input->TypeAsProto()); auto& n_input = top_level_graph->GetOrCreateNodeArg(name, type_proto.get()); context->manually_added_graph_inputs[n_input.Name()] = &n_input; - LOGS_DEFAULT(VERBOSE) << "[Nv EP] \t" << n_input.Name() << " is added as an explicit input into the newly built graph"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] \t" << n_input.Name() << " is added as an explicit input into the newly built graph"; } } } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index 05e5f7659efac..f5ba66746c3c4 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -4,13 +4,15 @@ #include "core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h" #include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/common/make_string.h" #include "core/common/parse_string.h" #include "core/framework/provider_options_utils.h" #include "core/providers/cuda/cuda_common.h" namespace onnxruntime { -NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { +NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options, + const ConfigOptions& session_options) { NvExecutionProviderInfo info{}; void* user_compute_stream = nullptr; void* onnx_bytestream = nullptr; @@ -58,6 +60,25 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi info.user_compute_stream = user_compute_stream; info.has_user_compute_stream = (user_compute_stream != nullptr); info.onnx_bytestream = onnx_bytestream; + + // EP context settings + const auto embed_enable = session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0"); + if (embed_enable == "0") { + info.dump_ep_context_model = false; + } else if (embed_enable == "1") { + info.dump_ep_context_model = true; + } else { + ORT_THROW("Invalid ", kOrtSessionOptionEpContextEnable, " must 0 or 1"); + } + info.ep_context_file_path = session_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + + const auto embed_mode = std::stoi(session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1")); + if (0 <= embed_mode || embed_mode < 2) { + info.ep_context_embed_mode = embed_mode; + } else { + ORT_THROW("Invalid ", kOrtSessionOptionEpContextEmbedMode, " must 0 or 1"); + } + return info; } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h index c3c4dba1ed982..626039e5ef7c8 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h @@ -8,8 +8,9 @@ #include "core/framework/ortdevice.h" #include "core/framework/provider_options.h" #include "core/framework/framework_provider_common.h" -#include "core/session/onnxruntime_c_api.h" #include "core/framework/library_handles.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/providers/shared_library/provider_api.h" #define TRT_DEFAULT_OPTIMIZER_LEVEL 3 @@ -19,18 +20,10 @@ struct NvExecutionProviderInfo { int device_id{0}; bool has_user_compute_stream{false}; void* user_compute_stream{nullptr}; - bool has_trt_options{false}; int max_partition_iterations{1000}; int min_subgraph_size{1}; size_t max_workspace_size{0}; - bool fp16_enable{false}; - bool int8_enable{false}; - std::string int8_calibration_table_name{""}; - bool int8_use_native_calibration_table{false}; - bool dla_enable{false}; - int dla_core{0}; bool dump_subgraphs{false}; - bool engine_cache_enable{false}; std::string engine_cache_path{""}; bool weight_stripped_engine_enable{false}; std::string onnx_model_folder_path{""}; @@ -40,16 +33,10 @@ struct NvExecutionProviderInfo { std::string engine_decryption_lib_path{""}; bool force_sequential_engine_build{false}; bool context_memory_sharing_enable{false}; - bool layer_norm_fp32_fallback{false}; - bool timing_cache_enable{false}; std::string timing_cache_path{""}; - bool force_timing_cache{false}; bool detailed_build_log{false}; - bool build_heuristics_enable{false}; bool sparsity_enable{false}; - int builder_optimization_level{3}; int auxiliary_streams{-1}; - std::string tactic_sources{""}; std::string extra_plugin_lib_paths{""}; std::string profile_min_shapes{""}; std::string profile_max_shapes{""}; @@ -59,10 +46,10 @@ struct NvExecutionProviderInfo { std::string ep_context_file_path{""}; int ep_context_embed_mode{0}; std::string engine_cache_prefix{""}; - bool engine_hw_compatible{false}; std::string op_types_to_exclude{""}; - static NvExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); + static NvExecutionProviderInfo FromProviderOptions(const ProviderOptions& options, + const ConfigOptions& session_options); static ProviderOptions ToProviderOptions(const NvExecutionProviderInfo& info); std::vector custom_op_domain_list; }; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h index 169127f222949..046010deedf62 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h @@ -153,22 +153,22 @@ std::unordered_map>>>& shape_ranges) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] In SerializeProfileV2()"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] In SerializeProfileV2()"; // Serialize profile flexbuffers::Builder builder; auto tensor_map_start = builder.StartMap(); for (auto tensor_it = shape_ranges.begin(); tensor_it != shape_ranges.end(); tensor_it++) { // iterate tensors - LOGS_DEFAULT(VERBOSE) << "[Nv EP] input tensor is '" << tensor_it->first.c_str() << "'"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] input tensor is '" << tensor_it->first.c_str() << "'"; builder.TypedVector(tensor_it->first.c_str(), [&] { for (auto dim_it = tensor_it->second.begin(); dim_it != tensor_it->second.end(); dim_it++) { size_t num_profiles = dim_it->second.size(); for (size_t i = 0; i < num_profiles; i++) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] profile #" << i << ", dim is " << dim_it->first; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] profile #" << i << ", dim is " << dim_it->first; builder.Int(dim_it->first); builder.Int(dim_it->second[i][0]); builder.Int(dim_it->second[i][1]); builder.Int(dim_it->second[i][2]); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] " << dim_it->first << ", " << dim_it->second[i][0] << ", " << dim_it->second[i][1] << ", " << dim_it->second[i][2]; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] " << dim_it->first << ", " << dim_it->second[i][0] << ", " << dim_it->second[i][1] << ", " << dim_it->second[i][2]; } } }); @@ -233,7 +233,7 @@ void SerializeProfileV2(const std::string& file_name, std::unordered_map>>> DeserializeProfileV2(std::ifstream& infile) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] In DeserializeProfileV2()"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] In DeserializeProfileV2()"; // Load flexbuffer infile.seekg(0, std::ios::end); size_t length = infile.tellg(); @@ -248,7 +248,7 @@ std::unordered_map>> inner_map; std::vector> profile_vector; @@ -265,7 +265,7 @@ std::unordered_map>>& profile_opt_shapes) { std::ifstream profile_file(file_name, std::ios::binary | std::ios::in); if (!profile_file) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] " << file_name << " doesn't exist."; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] " << file_name << " doesn't exist."; return true; } @@ -313,7 +313,7 @@ bool CompareProfiles(const std::string& file_name, // Check number of dynamic shape inputs if (profile_min_shapes.size() != shape_ranges.size()) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Numbers of dynamic shape inputs are not the same."; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Numbers of dynamic shape inputs are not the same."; return true; } @@ -321,7 +321,7 @@ bool CompareProfiles(const std::string& file_name, for (auto tensor_it = shape_ranges.begin(); tensor_it != shape_ranges.end(); tensor_it++) { // iterate tensors auto tensor_name = tensor_it->first; if (profile_min_shapes.find(tensor_name) == profile_min_shapes.end()) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Tensor name '" << tensor_name << "' doesn't exist in trt_profile_min_shapes."; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Tensor name '" << tensor_name << "' doesn't exist in trt_profile_min_shapes."; return true; } @@ -330,35 +330,35 @@ bool CompareProfiles(const std::string& file_name, auto num_profiles = GetNumProfiles(profile_min_shapes); if (dim_it->second.size() != static_cast(num_profiles)) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Numbers of profiles are not the same."; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Numbers of profiles are not the same."; return true; } for (size_t i = 0; i < dim_it->second.size(); i++) { // iterate (multiple) profile(s) auto shape_values = dim_it->second[i]; if (dim > (profile_min_shapes[tensor_name][i].size() - 1)) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] dimension " << dim << " of '" << tensor_name << "' in " << file_name << " exceeds the total dimension of trt_profile_min_shapes."; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] dimension " << dim << " of '" << tensor_name << "' in " << file_name << " exceeds the total dimension of trt_profile_min_shapes."; return true; } - LOGS_DEFAULT(VERBOSE) << "[Nv EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_min_shapes[tensor_name][i][dim]; - LOGS_DEFAULT(VERBOSE) << "[Nv EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[0] << " in " << file_name; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_min_shapes[tensor_name][i][dim]; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[0] << " in " << file_name; if (profile_min_shapes[tensor_name][i][dim] != shape_values[0]) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] min shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] min shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; return true; } - LOGS_DEFAULT(VERBOSE) << "[Nv EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_max_shapes[tensor_name][i][dim]; - LOGS_DEFAULT(VERBOSE) << "[Nv EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[1] << " in " << file_name; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_max_shapes[tensor_name][i][dim]; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[1] << " in " << file_name; if (profile_max_shapes[tensor_name][i][dim] != shape_values[1]) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] max shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] max shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; return true; } - LOGS_DEFAULT(VERBOSE) << "[Nv EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_opt_shapes[tensor_name][i][dim]; - LOGS_DEFAULT(VERBOSE) << "[Nv EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[2] << " in " << file_name; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_opt_shapes[tensor_name][i][dim]; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[2] << " in " << file_name; if (profile_opt_shapes[tensor_name][i][dim] != shape_values[2]) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] opt shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] opt shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; return true; } } @@ -461,7 +461,7 @@ HashValue TRTGenerateId(const GraphViewer& graph_viewer, std::string trt_version if (main_graph.ModelPath().has_filename()) { std::string model_name = PathToUTF8String(main_graph.ModelPath().filename()); - LOGS_DEFAULT(INFO) << "[Nv EP] Model name is " << model_name; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Model name is " << model_name; // Ensure enough characters are hashed in case model names are too short const size_t model_name_length = model_name.size(); constexpr size_t hash_string_length = 500; @@ -471,7 +471,7 @@ HashValue TRTGenerateId(const GraphViewer& graph_viewer, std::string trt_version } hash_str(repeat_model_name); } else { - LOGS_DEFAULT(INFO) << "[Nv EP] Model path is empty"; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Model path is empty"; } // fingerprint current graph by hashing graph inputs @@ -567,7 +567,7 @@ bool MakeInputNameShapePair(std::string pair_string, std::pair& domain_list, const std::string extra_plugin_lib_paths) override { common::Status status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); if (!status.IsOK()) { - return CreateStatus(ORT_FAIL, "[Nv EP] Can't create custom ops for TRT plugins."); + return CreateStatus(ORT_FAIL, "[NvTensorRTRTX EP] Can't create custom ops for TRT plugins."); } return nullptr; } @@ -79,7 +79,7 @@ std::unique_ptr NvProviderFactory::CreateProvider(const OrtS provider_options[key.substr(key_prefix.size())] = value; } } - NvExecutionProviderInfo info = onnxruntime::NvExecutionProviderInfo::FromProviderOptions(provider_options); + NvExecutionProviderInfo info = onnxruntime::NvExecutionProviderInfo::FromProviderOptions(provider_options, config_options); auto ep = std::make_unique(info); ep->SetLogger(reinterpret_cast(&session_logger)); @@ -91,14 +91,26 @@ struct Nv_Provider : Provider { std::shared_ptr CreateExecutionProviderFactory(int device_id) override { NvExecutionProviderInfo info; info.device_id = device_id; - info.has_trt_options = false; return std::make_shared(info); } - std::shared_ptr CreateExecutionProviderFactory(const void* options) { - const ProviderOptions* provider_options = reinterpret_cast(options); - NvExecutionProviderInfo info = onnxruntime::NvExecutionProviderInfo::FromProviderOptions(*provider_options); + std::shared_ptr CreateExecutionProviderFactory(const void* param) { + if (param == nullptr) { + LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Passed NULL options to CreateExecutionProviderFactory()"; + return nullptr; + } + + std::array pointers_array = *reinterpret_cast*>(param); + const ProviderOptions* provider_options = reinterpret_cast(pointers_array[0]); + const ConfigOptions* config_options = reinterpret_cast(pointers_array[1]); + + if (provider_options == nullptr) { + LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Passed NULL ProviderOptions to CreateExecutionProviderFactory()"; + return nullptr; + } + + NvExecutionProviderInfo info = onnxruntime::NvExecutionProviderInfo::FromProviderOptions(*provider_options, *config_options); return std::make_shared(info); } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h index 7eeb6cce4fa03..616f5f1fbe754 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h @@ -9,9 +9,12 @@ #include "core/providers/providers.h" namespace onnxruntime { +struct SessionOptions; + // defined in provider_bridge_ort.cc struct NvProviderFactoryCreator { static std::shared_ptr Create(int device_id); - static std::shared_ptr Create(const ProviderOptions& provider_options); + static std::shared_ptr Create(const ProviderOptions& provider_options_map, + const SessionOptions* session_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 4f84e853f999c..25decd8f2ce8f 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -213,7 +213,7 @@ void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, const std::string& ctx_model_path) { std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary); model_proto->SerializeToOstream(dump); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Dumped " + ctx_model_path; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Dumped " + ctx_model_path; } bool IsAbsolutePath(const std::string& path_string) { @@ -285,7 +285,7 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s(); *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), static_cast(context_binary.length()))); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; if (!(*trt_engine_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP could not deserialize engine from binary data"); @@ -324,7 +324,7 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph // The engine cache and context model (current model) should be in the same directory std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_)); auto engine_cache_path = ctx_model_dir.append(cache_path); - LOGS_DEFAULT(VERBOSE) << "[Nv EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled if (!weight_stripped_engine_refit_) { @@ -335,7 +335,7 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph if (weight_stripped_engine_refit_) { const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); if (std::filesystem::exists(refitted_engine_cache_path)) { - LOGS_DEFAULT(VERBOSE) << "[Nv EP] " + refitted_engine_cache_path.string() + " exists."; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] " + refitted_engine_cache_path.string() + " exists."; engine_cache_path = refitted_engine_cache_path.string(); weight_stripped_engine_refit_ = false; } @@ -358,7 +358,7 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP could not deserialize engine from cache: " + engine_cache_path.string()); } - LOGS_DEFAULT(VERBOSE) << "[Nv EP] DeSerialized " + engine_cache_path.string(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] DeSerialized " + engine_cache_path.string(); if (weight_stripped_engine_refit_) { const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s(); @@ -394,14 +394,14 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s(); // Verify if engine was compiled with ampere+ hardware compatibility enabled if (model_compute_capability == "80+") { - LOGS_DEFAULT(WARNING) << "[Nv EP] Engine is compatible to all Ampere+ GPU (except Jetson)"; + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is compatible to all Ampere+ GPU (except Jetson)"; if (std::stoi(compute_capability_) < 80) { - LOGS_DEFAULT(WARNING) << "[Nv EP] However, this GPU doesn't match. The compute capability of the GPU: " << compute_capability_; + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] However, this GPU doesn't match. The compute capability of the GPU: " << compute_capability_; } } else if (model_compute_capability != compute_capability_) { - LOGS_DEFAULT(WARNING) << "[Nv EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal"; - LOGS_DEFAULT(WARNING) << "[Nv EP] The compute capability of the engine: " << model_compute_capability; - LOGS_DEFAULT(WARNING) << "[Nv EP] The compute capability of the GPU: " << compute_capability_; + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal"; + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The compute capability of the engine: " << model_compute_capability; + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The compute capability of the GPU: " << compute_capability_; } } diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 13f09b9d9acdb..9ef7e4b86db5f 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -324,7 +324,7 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name, [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto, [[maybe_unused]] const onnxruntime::Node& fused_node) { -#ifndef RELEASE +#ifndef RELEASE if (openvino_ep::backend_utils::IsDebugEnabled()) { auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name.filename(); diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 9f33df54a4330..76812f9e83be6 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -54,7 +54,7 @@ #include "core/providers/js/js_provider_factory_creator.h" #endif -#if defined(USE_OPENVINO) +#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) #include "core/providers/openvino/openvino_provider_factory_creator.h" #endif @@ -66,7 +66,7 @@ #include "core/providers/rocm/rocm_provider_factory_creator.h" #endif -#if defined(USE_QNN) +#if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE) #include "core/providers/qnn/qnn_provider_factory_creator.h" #endif @@ -74,15 +74,15 @@ #include "core/providers/snpe/snpe_provider_factory_creator.h" #endif -#if defined(USE_TENSORRT) +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) #include "core/providers/tensorrt/tensorrt_provider_factory_creator.h" #endif -#if defined(USE_NV) +#if defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE) #include "core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h" #endif -#if defined(USE_VITISAI) +#if defined(USE_VITISAI) || defined(USE_VITISAI_PROVIDER_INTERFACE) #include "core/providers/vitisai/vitisai_provider_factory_creator.h" #endif @@ -105,7 +105,3 @@ #if defined(USE_AZURE) #include "core/providers/azure/azure_provider_factory_creator.h" #endif - -#if defined(USE_NV) -#include "core/providers/nv_tensorrt_rtx/nv_provider_factory_creator.h" -#endif diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 77579dfc793ee..efb4afcb88c85 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -43,6 +43,7 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("Elu", *this); CreateSimpleOpBuilder("Round", *this); CreateSimpleOpBuilder("Where", *this); + CreateSimpleOpBuilder("ScatterND", *this); CreateSimpleOpBuilder("Sigmoid", *this); CreateSimpleOpBuilder("Sin", *this); CreateSimpleOpBuilder("Sqrt", *this); @@ -133,6 +134,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateResizeOpBuilder("Resize", *this); } + { + CreateUpsampleOpBuilder("Upsample", *this); + } + { CreateTopKOpBuilder("TopK", *this); } @@ -169,6 +174,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateExpandOpBuilder("Expand", *this); } + { + CreateEinsumOpBuilder("Einsum", *this); + } + { CreateMatMulOpBuilder("MatMul", *this); } diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index e11eae84341fe..aa1039f857f8e 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -75,6 +75,8 @@ void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateUpsampleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + void CreateTopKOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); @@ -98,5 +100,7 @@ void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& o void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + +void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 02d2bf22b8144..d7432f35e61cf 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -204,7 +204,10 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, std::string output_name; }; std::vector cast_node_info_vec; - + auto mem_type = QNN_TENSORMEMTYPE_RAW; + if (true == qnn_model_wrapper.GetModelSettings().htp_shared_memory) { + mem_type = QNN_TENSORMEMTYPE_MEMHANDLE; + } const auto output_count = GetOutputCountQnnRequired(node_unit); for (size_t output_i = 0; output_i < output_count; ++output_i) { const auto& output_name = outputs[output_i].node_arg.Name(); @@ -255,7 +258,8 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, QNN_TENSOR_TYPE_NATIVE, supported_qnn_data_type, output_info.quant_param.Copy(), - std::move(cast_output_shape)); + std::move(cast_output_shape), {}, + mem_type); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor."); output_names.push_back(cast_input_name); cast_node_info_vec.push_back({cast_node_name, cast_input_name, output_name}); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 37060fcd9ba93..5474db0590f92 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -155,6 +155,7 @@ class BaseOpBuilder : public IOpBuilder { {"ReduceSum", QNN_OP_REDUCE_SUM}, {"Round", QNN_OP_ELEMENT_WISE_ROUND}, {"Where", QNN_OP_ELEMENT_WISE_SELECT}, + {"ScatterND", QNN_OP_SCATTER_ND}, {"Sigmoid", QNN_OP_SIGMOID}, {"Sin", QNN_OP_ELEMENT_WISE_SIN}, {"Slice", QNN_OP_STRIDED_SLICE}, @@ -192,6 +193,7 @@ class BaseOpBuilder : public IOpBuilder { {"Reshape", QNN_OP_RESHAPE}, {"Resize", QNN_OP_RESIZE}, + {"Upsample", QNN_OP_RESIZE}, {"Flatten", QNN_OP_RESHAPE}, {"Squeeze", QNN_OP_RESHAPE}, {"Unsqueeze", QNN_OP_RESHAPE}, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc index 193b507083360..a1a658d5d963c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc @@ -94,13 +94,13 @@ Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N if (node_unit.Inputs().size() > 1) { const auto& min_input_name = node_unit.Inputs()[1].node_arg.Name(); if (!min_input_name.empty() && !qnn_model_wrapper.IsConstantInput(min_input_name)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic min/max."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic min/max."); } } if (node_unit.Inputs().size() > 2) { const auto& max_input_name = node_unit.Inputs()[2].node_arg.Name(); if (!max_input_name.empty() && !qnn_model_wrapper.IsConstantInput(max_input_name)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic min/max."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic min/max."); } } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc new file mode 100644 index 0000000000000..9db0b5202dcd4 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc @@ -0,0 +1,396 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/cpu/tensor/slice_helper.h" + +namespace { + +// Represented as a tuple of 3 strings . +// The equation string is expected to follow the format "term_1,term_2->result" +using Equation = std::tuple; + +/** + * @brief Parses an equation string into its components if it adheres to the expected format. + * + * @param equation_string The input equation string to parse. + * @return A std::optional containing a tuple of 3 strings (term_1, term_2, result) if the parsing is successful. + * Returns std::nullopt if the input string is invalid or does not conform to the expected format. + */ +std::optional ParseEquation(std::string_view equation_string) { + std::string equation(equation_string); + equation.erase(std::remove(equation.begin(), equation.end(), ' '), + equation.end()); + if (equation.empty()) { + return std::nullopt; + } + auto index_arrow = equation.find("->"); + if (index_arrow == std::string::npos) { + return std::nullopt; + } + const std::string lhs = equation.substr(0, index_arrow); + const std::string result = equation.substr(index_arrow + 2); + if (lhs.empty() || result.empty()) { + return std::nullopt; + } + auto index_comma = lhs.find(","); + if (index_comma == std::string::npos) { + return std::nullopt; + } + const std::string term_1 = lhs.substr(0, index_comma); + const std::string term_2 = lhs.substr(index_comma + 1); + if (term_1.empty() || term_2.empty()) { + return std::nullopt; + } + if (term_1.size() < 2) { + return std::nullopt; + } + if (term_1.size() != term_2.size()) { + return std::nullopt; + } + if (term_1.size() != result.size()) { + return std::nullopt; + } + if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::islower(c); })) { + return std::nullopt; + } + if (!std::all_of(term_2.begin(), term_2.end(), [](unsigned char c) { return std::islower(c); })) { + return std::nullopt; + } + if (!std::all_of(result.begin(), result.end(), [](unsigned char c) { return std::islower(c); })) { + return std::nullopt; + } + return std::make_tuple(term_1, term_2, result); +} + +bool IsEquationMatMul(const Equation& equation) { + // MatMul: e.g., "ij,jk->ik" + const auto& [term_1, term_2, result] = equation; + const size_t num_dims = term_1.size(); + for (size_t i = 0; i < num_dims; ++i) { + if (i >= num_dims - 2) { + continue; + } + if (!(term_1[i] == term_2[i] && term_1[i] == result[i])) { + return false; + } + } + char term_1_m = term_1[num_dims - 2]; + char term_2_k = term_2[num_dims - 2]; + char result_m = result[num_dims - 2]; + char term_1_k = term_1[num_dims - 1]; + char term_2_n = term_2[num_dims - 1]; + char result_n = result[num_dims - 1]; + if (term_1_m != result_m) { + return false; + } + if (term_1_k != term_2_k) { + return false; + } + if (term_2_n != result_n) { + return false; + } + return true; +} + +bool IsEquationMatMulTransposeY(const Equation& equation) { + // MatMul with 2nd input transposed: e.g., "id,jd->ij" + const auto& [term_1, term_2, result] = equation; + const size_t num_dims = term_1.size(); + for (size_t i = 0; i < num_dims; ++i) { + if (i >= num_dims - 2) { + continue; + } + if (!(term_1[i] == term_2[i] && term_1[i] == result[i])) { + return false; + } + } + char term_1_m = term_1[num_dims - 2]; + char term_2_k = term_2[num_dims - 2]; + char result_m = result[num_dims - 2]; + char term_1_k = term_1[num_dims - 1]; + char term_2_n = term_2[num_dims - 1]; + char result_n = result[num_dims - 1]; + if (term_1_m != result_m) { + return false; + } + if (term_1_k != term_2_n) { + return false; + } + if (term_2_k != result_n) { + return false; + } + return true; +} + +bool IsEquationMatMulTransposeAll(const Equation& equation) { + // MatMul transpose both inputs and output, e.g., "bchq,bkhc->bkhq", "bkhq,bchk->bchq" + const auto& [term_1, term_2, result] = equation; + const size_t num_dims = term_1.size(); + if (num_dims != 4) { + return false; + } + if (term_1[0] != term_2[0] || term_1[0] != result[0]) { + return false; + } + char term_1_m = term_1[num_dims - 1]; + char term_1_k = term_1[num_dims - 3]; + char term_2_k = term_2[num_dims - 1]; + char term_2_n = term_2[num_dims - 3]; + char result_m = result[num_dims - 1]; + char result_n = result[num_dims - 3]; + if (term_1_m != result_m) { + return false; + } + if (term_1_k != term_2_k) { + return false; + } + if (term_2_n != result_n) { + return false; + } + return true; +} + +/** + * @brief Sets the parameter tensor names for a MatMul op. + * + * @param qnn_model_wrapper Pointer to the QnnModelWrapper instance that manages the QNN model. + * @param node_unit Reference to the NodeUnit representing the ONNX node for which the parameters are being set. + * @param transpose_in0 Boolean flag indicating whether the 1st input tensor should be transposed (default: false). + * @param transpose_in1 Boolean flag indicating whether the 2nd input tensor should be transposed (default: false). + * @return A vector of strings containing the names of the parameter tensors added to the QNN model. + */ +std::vector SetMatMulParamTensorNames( + onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper, + const onnxruntime::NodeUnit& node_unit, + bool transpose_in0 = false, + bool transpose_in1 = false) { + std::vector param_tensor_names; + Qnn_Scalar_t scalar_params[2] = {QNN_SCALAR_INIT, QNN_SCALAR_INIT}; + scalar_params[0].dataType = QNN_DATATYPE_BOOL_8; + scalar_params[1].dataType = QNN_DATATYPE_BOOL_8; + scalar_params[0].bool8Value = static_cast(transpose_in0); + scalar_params[1].bool8Value = static_cast(transpose_in1); + onnxruntime::qnn::QnnParamWrapper transpose_in0_param( + node_unit.Index(), node_unit.Name(), QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, scalar_params[0]); + onnxruntime::qnn::QnnParamWrapper transpose_in1_param( + node_unit.Index(), node_unit.Name(), QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1, scalar_params[1]); + param_tensor_names.push_back(transpose_in0_param.GetParamTensorName()); + param_tensor_names.push_back(transpose_in1_param.GetParamTensorName()); + qnn_model_wrapper->AddParamWrapper(std::move(transpose_in0_param)); + qnn_model_wrapper->AddParamWrapper(std::move(transpose_in1_param)); + return param_tensor_names; +} + +/** + * @brief Creates a MatMul operation with transposed inputs and output in a QNN model. + * + * @param qnn_model_wrapper Pointer to the QnnModelWrapper instance used to manage the QNN model. + * @param node_unit The NodeUnit representing the ONNX node to be converted. + * @param do_op_validation A boolean flag indicating whether to perform operation validation. + * @return Status indicating success or failure of the operation. + */ +Status CreateMatMulTransposeAll( + onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper, + const onnxruntime::NodeUnit& node_unit, + std::vector&& input_names, + bool do_op_validation) { + onnxruntime::qnn::TensorInfo input_info0{}, input_info1{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[0], input_info0)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[1], input_info1)); + std::vector input_shape0(input_info0.shape); + std::vector input_shape1(input_info1.shape); + std::swap(input_shape0[1], input_shape0[2]); + std::swap(input_shape1[1], input_shape1[2]); + const std::string input_transpos0 = input_names[0] + "_t0"; + const std::string input_transpos1 = input_names[1] + "_t1"; + const std::vector transpose_perm{0, 2, 1, 3}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddTransposeNode( + /*node_index=*/node_unit.Index(), + /*input_name=*/input_names[0], + /*output_name=*/input_transpos0, + /*input_shape=*/input_info0.shape, + /*transpose_perm=*/transpose_perm, + /*output_shape=*/input_shape0, + /*qnn_data_type=*/input_info0.qnn_data_type, + /*quantize_param=*/input_info0.quant_param.Copy(), + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/qnn_model_wrapper->IsGraphInput(input_names[0]))); + ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddTransposeNode( + /*node_index=*/node_unit.Index(), + /*input_name=*/input_names[1], + /*output_name=*/input_transpos1, + /*input_shape=*/input_info1.shape, + /*transpose_perm=*/transpose_perm, + /*output_shape=*/input_shape1, + /*qnn_data_type=*/input_info1.qnn_data_type, + /*quantize_param=*/input_info1.quant_param.Copy(), + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/qnn_model_wrapper->IsGraphInput(input_names[1]))); + onnxruntime::qnn::TensorInfo matmul_output_info{}; + const auto& output = node_unit.Outputs()[0]; + ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(output, matmul_output_info)); + const std::string matmul_output_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_matmul"; + std::vector matmul_output_shape(matmul_output_info.shape); + std::swap(matmul_output_shape[1], matmul_output_shape[2]); + onnxruntime::qnn::QnnTensorWrapper matmul_output_wrapper( + matmul_output_name, QNN_TENSOR_TYPE_NATIVE, matmul_output_info.qnn_data_type, + matmul_output_info.quant_param.Copy(), std::vector(matmul_output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(matmul_output_wrapper)), + node_unit.OpType() + " failed to add tensor."); + std::vector param_tensor_names = SetMatMulParamTensorNames( + qnn_model_wrapper, node_unit, /*transpose_in0=*/false, /*transpose_in1=*/false); + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(/*qnn_node_name=*/onnxruntime::qnn::utils::GetNodeName(node_unit), + /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW, + /*qnn_node_type=*/QNN_OP_MAT_MUL, + /*input_names=*/{input_transpos1, input_transpos0}, + /*output_names=*/{matmul_output_name}, + /*param_tensor_names=*/std::move(param_tensor_names), + /*do_op_validation=*/do_op_validation), + node_unit.OpType() + " failed to add node."); + std::vector transpose_output_shape(matmul_output_info.shape); + ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddTransposeNode( + /*node_index=*/node_unit.Index(), + /*input_name=*/matmul_output_name, + /*output_name=*/output.node_arg.Name(), + /*input_shape=*/std::move(matmul_output_shape), + /*transpose_perm=*/transpose_perm, + /*output_shape=*/matmul_output_info.shape, + /*tensor_data_type=*/matmul_output_info.qnn_data_type, + /*quantize_param=*/matmul_output_info.quant_param.Copy(), + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/qnn_model_wrapper->IsGraphInput(output.node_arg.Name()), + /*is_for_output=*/qnn_model_wrapper->IsGraphOutput(output.node_arg.Name()))); + return Status::OK(); +} + +} // namespace + +namespace onnxruntime { +namespace qnn { + +class EinsumOpBuilder : public BaseOpBuilder { + public: + EinsumOpBuilder() : BaseOpBuilder("EinsumOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(EinsumOpBuilder); + + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT; +}; + +Status EinsumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const { + if (node_unit.Inputs().size() < 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " requires at least 2 inputs."); + } + NodeAttrHelper node_helper{node_unit}; + const std::string equation = node_helper.Get("equation", std::string("")); + std::optional parsed_equation = ParseEquation(equation); + if (!parsed_equation.has_value()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); + } + if (!IsEquationMatMul(parsed_equation.value()) && + !IsEquationMatMulTransposeY(parsed_equation.value()) && + !IsEquationMatMulTransposeAll(parsed_equation.value())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); + } + return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); +} + +Status EinsumOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + const auto& inputs = node_unit.Inputs(); + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[1], logger, input_names)); + return Status::OK(); +} + +Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + NodeAttrHelper node_helper(node_unit); + const std::string equation = node_helper.Get("equation", std::string("")); + std::optional parsed_equation = ParseEquation(equation); + if (IsEquationMatMul(parsed_equation.value())) { + std::vector param_tensor_names = SetMatMulParamTensorNames( + &qnn_model_wrapper, node_unit, /*transpose_in0=*/false, /*transpose_in1=*/false); + ORT_RETURN_IF_ERROR(ProcessOutputs(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_names=*/std::move(input_names), + /*param_tensor_names=*/std::move(param_tensor_names), + /*logger=*/logger, + /*do_op_validation=*/do_op_validation, + /*qnn_op_type=*/QNN_OP_MAT_MUL)); + } else if (IsEquationMatMulTransposeY(parsed_equation.value())) { + std::vector param_tensor_names = SetMatMulParamTensorNames( + &qnn_model_wrapper, node_unit, /*transpose_in0=*/false, /*transpose_in1=*/true); + ORT_RETURN_IF_ERROR(ProcessOutputs(/*qnn_model_wrapper=*/qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_names=*/std::move(input_names), + /*param_tensor_names=*/std::move(param_tensor_names), + /*logger=*/logger, + /*do_op_validation=*/do_op_validation, + /*qnn_op_type=*/QNN_OP_MAT_MUL)); + } else if (IsEquationMatMulTransposeAll(parsed_equation.value())) { + ORT_RETURN_IF_ERROR(CreateMatMulTransposeAll(&qnn_model_wrapper, node_unit, std::move(input_names), do_op_validation)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); + } + return Status::OK(); +} + +Status EinsumOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + QnnQuantParamsWrapper& quant_param) const { + if (!quant_param.IsPerTensor()) { + return Status::OK(); + } + + // Force the operator output to use the same quantization parameters as the input if nearly equal. + // This helps the HTP backend employ certain optimizations. + return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, + 0 /*input_index*/, output_index, qnn_data_type, quant_param); +} + +void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 229d86082f6dc..ab022df063c96 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -56,6 +56,13 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, padding_mode.c_str()); } + // To DO: Remove once QNN CPU supports ScatterND + const auto qnn_backend_type = qnn_model_wrapper.GetQnnBackendType(); + if (op_type == "ScatterND") { + ORT_RETURN_IF_NOT(qnn_backend_type == QnnBackendType::HTP, + "QNN EP only supports ScatterND op on HTP backend. Falling back to ORT CPU."); + } + // ONNX's Min, Max, and Sum operators accept a variable number of inputs (i.e., variadic). // However, QNN's Min, Max, and Add operators must take in exactly two inputs. if (op_type == "Min" || op_type == "Max") { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc index 19e5ee298f5fb..bcf4df8186dd2 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/slice_op_builder.cc @@ -46,7 +46,7 @@ Status SliceOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const for (size_t i = 1; i < input_count; i++) { const auto& next_input = node_unit.Inputs()[i].node_arg.Name(); if (!qnn_model_wrapper.IsConstantInput(next_input)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic slice."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic slice."); } } } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc index e8acaf75143d8..8704420d98ead 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc @@ -39,12 +39,12 @@ constexpr int32_t GetDefaultAxisAttribute(int opset_version) { return opset_version < 13 ? 1 : -1; } -std::vector FlattenShapeFromAxis(std::vector& input_shape, int32_t axis) { +std::vector FlattenShapeFromAxis(const std::vector& input_shape, int32_t axis) { /* Return the shape with all dimensions multiplied onward from the specified axis. If axis is 0, the returned shape will include an additional batch of size 1 as the first dimension. */ - assert(axis >= 0 && axis < input_shape.size()); + assert(axis >= 0 && static_cast(axis) < input_shape.size()); std::vector output_shape(input_shape.begin(), input_shape.begin() + axis); if (axis == 0) { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc index 555992ef00bfe..cba1faaa4fa2d 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc @@ -42,7 +42,7 @@ Status TileOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, std::vector& input_names, bool do_op_validation) const { const auto& inputs = node_unit.Inputs(); - // QNN Tile only support 1 input, the 2nd input need to be initialier and set as Qnn node parameter + // QNN Tile only support 1 input, the 2nd input need to be initializer and set as Qnn node parameter if (do_op_validation) { auto& repeats_input_name = inputs[1].node_arg.Name(); ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(repeats_input_name), @@ -60,7 +60,7 @@ Status TileOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra const logging::Logger& logger, bool do_op_validation) const { std::vector param_tensor_names; - // Already confirmed repeats input is initailizer in ProcessInputs() + // Already confirmed repeats input is initializer in ProcessInputs() const auto& repeats_input_name = node_unit.Inputs()[1].node_arg.Name(); std::vector unpacked_tensor; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc new file mode 100644 index 0000000000000..48214f92b1a61 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/upsample_op_builder.cc @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include + +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" + +namespace onnxruntime { +namespace qnn { + +class UpsampleOpBuilder : public BaseOpBuilder { + public: + UpsampleOpBuilder() : BaseOpBuilder("UpsampleOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(UpsampleOpBuilder); + + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const final ORT_MUST_USE_RESULT; + + protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + QnnQuantParamsWrapper& quant_param) const override ORT_MUST_USE_RESULT; + + private: + const std::unordered_map supported_modes = { + {"nearest", QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST}, + {"linear", QNN_OP_RESIZE_INTERPOLATION_MODE_LINEAR}, + {"cubic", QNN_OP_RESIZE_INTERPOLATION_MODE_CUBIC}}; + + // Info for Onnx Upsample attribute {, } + const OnnxAttrInfo onnx_mode_attr = {"mode", "nearest"}; +}; + +static Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector& param_tensor_names, + const Qnn_Scalar_t& qnn_scalar, + const std::string& qnn_scalar_param_name) { + QnnParamWrapper qnn_param_wrapper(node_unit.Index(), node_unit.Name(), qnn_scalar_param_name, qnn_scalar); + param_tensor_names.push_back(qnn_param_wrapper.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper)); + + return Status::OK(); +} + +Status UpsampleOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const { + // Resize ops are sensitive with data layout, no special validation so far + // The nodes from 1st call of GetCapability do not get layout transformer applied, it's still NCHW + // The nodes from 2nd call of GetCapability get layout transformer applied, it's NHWC + // Need to do op validation in 1st call of GetCapability + if (node_unit.Domain() == kMSInternalNHWCDomain) { + return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); + } + + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + NodeAttrHelper node_helper(node_unit); + + // Check mode + const std::string interp_mode = GetOnnxAttr(node_helper, onnx_mode_attr); + ORT_RETURN_IF_NOT(supported_modes.find(interp_mode) != supported_modes.end(), + "QNN EP: Resize does not support mode ", interp_mode.c_str()); + + const auto& input_0 = node_unit.Inputs()[0]; + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), + "QNN EP: Cannot get input shape for Onnx Upsample ", input_0.node_arg.Name().c_str()); + const size_t input_rank = input_shape.size(); + + ORT_RETURN_IF(is_npu_backend && (input_rank < 3 || input_rank > 5), + "QNN EP: The input rank for Resize must be at least 3 and no greater than 5 on the HTP."); + + const auto& output_0 = node_unit.Outputs()[0]; + std::vector output_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output_0.node_arg, output_shape), + "QNN EP: Cannot get output shape for Onnx Upsample ", output_0.node_arg.Name().c_str(), + ". Dynamic scales input is not supported in QNN EP."); + + // Check that only the spatial dimensions (width, height) are resized. The batch_size (N) and channels (C) should + // be untouched. This code runs before layout transformation, so we know that the current layout is "channel first" + // (e.g., N, C, S1, S2, ..., SN). + ORT_RETURN_IF_NOT(input_shape[0] == output_shape[0] && input_shape[1] == output_shape[1], + "QNN EP: Resize may only change the spatial dimensions."); + + if (!is_npu_backend) { + ONNX_NAMESPACE::DataType input_data_type = input_0.node_arg.Type(); + ORT_RETURN_IF(input_data_type != ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float"), + "QNN EP: Data type ", input_data_type->c_str(), + " is not supported for Resize operator in CPU backend."); + } + + return Status::OK(); +} + +Status UpsampleOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + const int opset_version = node_unit.SinceVersion(); + const auto& inputs = node_unit.Inputs(); + + if (opset_version > 7 && do_op_validation) { + const std::string& scales_input_name = inputs[1].node_arg.Name(); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(scales_input_name), + "QNN doesn't support dynamic scales input for ONNX Upsample op ", node_unit.Name().c_str()); + } + + // Only need to consider the first input of Onnx upsample. + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); + + return Status::OK(); +} + +Status UpsampleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + std::vector param_tensor_names; + NodeAttrHelper node_helper(node_unit); + const std::string interp_mode = GetOnnxAttr(node_helper, onnx_mode_attr); + + const auto& input_0 = node_unit.Inputs()[0]; + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), + "QNN EP: Cannot get input shape for Onnx Upsample ", input_0.node_arg.Name().c_str()); + + const size_t input_rank = input_shape.size(); + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + std::string qnn_op_type = GetQnnOpType(node_unit.OpType()); + + if (is_npu_backend && input_rank == 4 && interp_mode != "cubic") { + // Translate QNN's Resize to QNN's ResizeNearestNeighbor/ResizeBilinear to achieve better performance on + // the HTP backend. QNN's ResizeNearestNeighbor and ResizeBilinear are only supported when input rank is 4. + qnn_op_type = (interp_mode == "nearest") ? QNN_OP_RESIZE_NEAREST_NEIGHBOR : QNN_OP_RESIZE_BILINEAR; + + // Parameter 'align_corners' + Qnn_Scalar_t qnn_align_corners = QNN_SCALAR_INIT; + qnn_align_corners.dataType = QNN_DATATYPE_BOOL_8; + qnn_align_corners.bool8Value = false; + const std::string align_corners_param_name = (qnn_op_type == QNN_OP_RESIZE_BILINEAR) + ? QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS + : QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_ALIGN_CORNERS; + + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, + qnn_align_corners, align_corners_param_name)); + + // Parameter 'half_pixel_centers' + Qnn_Scalar_t qnn_half_pixel_centers = QNN_SCALAR_INIT; + qnn_half_pixel_centers.dataType = QNN_DATATYPE_BOOL_8; + qnn_half_pixel_centers.bool8Value = false; + const std::string half_pixel_centers_param_name = (qnn_op_type == QNN_OP_RESIZE_BILINEAR) + ? QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS + : QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_HALF_PIXEL_CENTERS; + + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, + qnn_half_pixel_centers, half_pixel_centers_param_name)); + + if (qnn_op_type == QNN_OP_RESIZE_BILINEAR) { + // Parameter 'antialias' + Qnn_Scalar_t qnn_antialias = QNN_SCALAR_INIT; + qnn_antialias.dataType = QNN_DATATYPE_BOOL_8; + qnn_antialias.bool8Value = false; + + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, + qnn_antialias, QNN_OP_RESIZE_BILINEAR_PARAM_ANTIALIAS)); + } + } else { + // Remain as QNN's Resize. + // Parameter 'exclude_outside' + Qnn_Scalar_t qnn_exclude_outside = QNN_SCALAR_INIT; + qnn_exclude_outside.dataType = QNN_DATATYPE_BOOL_8; + qnn_exclude_outside.bool8Value = false; + + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, + qnn_exclude_outside, QNN_OP_RESIZE_PARAM_EXCLUDE_OUTSIDE)); + + // Parameter 'transformation_mode' + Qnn_Scalar_t qnn_transformation_mode = QNN_SCALAR_INIT; + qnn_transformation_mode.dataType = QNN_DATATYPE_UINT_32; + qnn_transformation_mode.uint32Value = (supported_modes.at(interp_mode) == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) + ? static_cast(QNN_OP_RESIZE_TRANSFORMATION_MODE_HALF_PIXEL) + : static_cast(QNN_OP_RESIZE_TRANSFORMATION_MODE_ASYMMETRIC); + + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, + qnn_transformation_mode, QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE)); + + // Parameter 'interpolation_mode' + Qnn_Scalar_t qnn_interp_mode = QNN_SCALAR_INIT; + qnn_interp_mode.dataType = QNN_DATATYPE_UINT_32; + qnn_interp_mode.uint32Value = static_cast(supported_modes.at(interp_mode)); + + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, + qnn_interp_mode, QNN_OP_RESIZE_PARAM_INTERPOLATION_MODE)); + + // Parameter 'nearest_mode'. Process only when 'interpolation_mode' is NEAREST. + if (qnn_interp_mode.uint32Value == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) { + Qnn_Scalar_t qnn_nearest_mode = QNN_SCALAR_INIT; + qnn_nearest_mode.dataType = QNN_DATATYPE_UINT_32; + qnn_nearest_mode.uint32Value = static_cast(QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_FLOOR); + + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names, + qnn_nearest_mode, QNN_OP_RESIZE_PARAM_NEAREST_MODE)); + } + } + + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, + std::move(input_names), + std::move(param_tensor_names), + logger, do_op_validation, qnn_op_type)); + + return Status::OK(); +} + +Status UpsampleOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + QnnQuantParamsWrapper& quant_param) const { + if (!quant_param.IsPerTensor()) { + return Status::OK(); + } + + // Force Resize op's output to use the same quantization parameters as the input if nearly equal. + // This helps the HTP backend employ certain optimizations. + return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, + 0 /*input_index*/, output_index, qnn_data_type, quant_param); +} + +void CreateUpsampleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index aea354d0550b7..edd1f3e9eb53b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1850,9 +1850,10 @@ Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t cont if (did_register) { HtpSharedMemoryAllocator::AllocationCleanUpFn unregister_mem_handle = [&logger = *logger_, + shared_memory_address, weak_backend_manager = weak_from_this(), weak_context_handle_record = std::weak_ptr{context_handle_record}]( - void* shared_memory_address) { + void* /* allocation_base_address */) { // Lock QnnBackendManager shared_ptr to ensure that QNN interface is still valid. auto backend_manager = weak_backend_manager.lock(); if (!backend_manager) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index ee4f385f03889..0d7bc0ba9f4c7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -188,10 +188,6 @@ class QnnTensorWrapper { SetQnnTensorClientBuf(qnn_tensor_, client_buf_); } - if (mem_type != QNN_TENSORMEMTYPE_RAW) { - ORT_THROW("mem_type not supported for now."); - } - SetQnnTensorQParams(qnn_tensor_, quant_params_.Get()); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 3f2faea698259..8421bd4a99196 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -180,14 +180,16 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors()); if (Status::OK() != result) { - LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!"); + const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name(); + LOGS(logger, ERROR) << message; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false); if (Status::OK() != result) { - LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!"); + const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name(); + LOGS(logger, ERROR) << message; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message); } return Status::OK(); @@ -200,7 +202,9 @@ static Status BindQnnTensorMemoryToOrtValueMemory(const logging::Logger& logger, Qnn_ContextHandle_t qnn_context, Qnn_Tensor_t& qnn_tensor) { // either set qnn_tensor memHandle or clientBuf - const bool uses_shared_memory = ort_value_memory_info == HtpSharedMemoryAllocator::AssociatedMemoryInfo(); + const static auto htp_shared_mem_info = HtpSharedMemoryAllocator::AssociatedMemoryInfo(); + const bool uses_shared_memory = (ort_value_memory_info.device.Type() == htp_shared_mem_info.device.Type() && + ort_value_memory_info.device.MemType() == htp_shared_mem_info.device.MemType()); if (!uses_shared_memory) { LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t clientBuf to ORT tensor memory."; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 7e6d7add668dc..39ec9dba18f07 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -68,9 +68,13 @@ Status QnnModelWrapper::MakeTensorWrapper(const NodeUnitIODef& tensor, QnnTensor ORT_RETURN_IF_ERROR(UnpackInitializerData(*tensor_info.initializer_tensor, unpacked_tensor)); } + Qnn_TensorMemType_t mem_type = QNN_TENSORMEMTYPE_RAW; + if (true == model_settings_.htp_shared_memory) { + mem_type = QNN_TENSORMEMTYPE_MEMHANDLE; + } tensor_wrapper = QnnTensorWrapper(tensor_name, GetTensorType(tensor_name), tensor_info.qnn_data_type, std::move(tensor_info.quant_param), std::move(tensor_info.shape), - std::move(unpacked_tensor)); + std::move(unpacked_tensor), mem_type); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 660f719feb32a..9ec6f470af9fd 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -30,6 +30,7 @@ struct TensorInfo { struct ModelSettings { bool offload_graph_io_quantization = false; + bool htp_shared_memory = false; }; class QnnModelWrapper { diff --git a/onnxruntime/core/providers/qnn/onnxruntime_providers_qnn.rc b/onnxruntime/core/providers/qnn/onnxruntime_providers_qnn.rc new file mode 100644 index 0000000000000..cf34384203909 --- /dev/null +++ b/onnxruntime/core/providers/qnn/onnxruntime_providers_qnn.rc @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file REQUIRES the following external definitions: +// FILE_DESC, FILE_NAME, VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE, and VER_STRING + +#include + +#if defined(DEBUG) || defined(_DEBUG) +#define VER_DEBUG VS_FF_DEBUG +#else +#define VER_DEBUG 0 +#endif + +// ----------------------------------------------------------------------------- + +VS_VERSION_INFO VERSIONINFO +FILEVERSION VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE +PRODUCTVERSION VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE +FILEFLAGSMASK VS_FFI_FILEFLAGSMASK +FILEFLAGS VER_DEBUG +FILEOS VOS__WINDOWS32 +FILETYPE VFT_DLL +FILESUBTYPE VFT2_UNKNOWN + +BEGIN + BLOCK "StringFileInfo" + BEGIN + BLOCK "040904E4" + BEGIN + VALUE "CompanyName", "Microsoft Corporation" + VALUE "FileDescription", FILE_DESC + VALUE "FileVersion", VER_STRING + VALUE "InternalName", FILE_DESC + VALUE "LegalCopyright", "\251 Microsoft Corporation. All rights reserved." + VALUE "OriginalFilename", FILE_NAME + VALUE "ProductName", "Microsoft\256 Windows\256 Operating System" + VALUE "ProductVersion", VER_STRING + END + END + + BLOCK "VarFileInfo" + BEGIN + VALUE "Translation", 0x409, 1252 + END +END diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index cb92e927ff65a..29e5dc0c25564 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -6,56 +6,19 @@ #include #include #include +#include + +#include #include "core/providers/qnn/ort_api.h" namespace onnxruntime::qnn { -/** - * HtpSharedMemoryAllocator allocation details - * - * The HTP shared memory allocator will allocate a block of shared memory larger than the amount requested in order to - * hold some additional info. - * Each allocation returned by HtpSharedMemoryAllocator::Alloc() is preceded by an AllocationHeader structure. - * - * For example, if Alloc(num_requested_bytes) is called, this is what the memory layout looks like: - * | AllocationHeader bytes | num_requested_bytes bytes | - * ^- address returned by Alloc() - * - * The AllocationHeader can be used to obtain the owning allocator instance, which in turn can be used to do other - * operations with that allocation, such as retrieving more info about the allocation. - */ - namespace { -struct AllocationHeader { - static constexpr std::array kAllocationHeaderMarker{'o', 'r', 't', 'a', 'l', 'l', 'o', 'c'}; - - // Marker bytes to verify as a sanity check. - std::array marker; - - // Pointer to the allocating allocator instance. - // Note: A critical assumption here is that the allocating allocator is not destroyed before the allocation is freed. - HtpSharedMemoryAllocator* allocator_ptr; - - AllocationHeader(HtpSharedMemoryAllocator* allocator_ptr) - : marker{kAllocationHeaderMarker}, - allocator_ptr{allocator_ptr} { - } - - ~AllocationHeader() { - marker.fill('\0'); - allocator_ptr = nullptr; - } -}; - size_t AllocationAlignment() { constexpr size_t min_allocation_alignment = 64; // Equal to MlasGetPreferredBufferAlignment() - return std::max(alignof(AllocationHeader), min_allocation_alignment); -} - -size_t DivRoundUp(size_t a, size_t b) { // TODO is there already a helper function somewhere for this? - return (a + b - 1) / b; + return min_allocation_alignment; } bool IsAligned(const void* address, size_t alignment) { @@ -63,34 +26,88 @@ bool IsAligned(const void* address, size_t alignment) { return (reinterpret_cast(address) & (alignment - 1)) == 0; } -size_t AllocationOffsetFromStartOfHeader() { - const size_t allocation_alignment = AllocationAlignment(); - const size_t offset = DivRoundUp(sizeof(AllocationHeader), allocation_alignment) * allocation_alignment; - return offset; +std::unique_ptr WrapSharedMemoryWithUniquePtr(void* shared_memory_raw, + const RpcMemApi& rpcmem_api) { + return {shared_memory_raw, rpcmem_api.free}; } -std::byte* GetAllocationHeaderAddress(void* allocation_address) { - auto* allocation_header_address = reinterpret_cast(allocation_address) - sizeof(AllocationHeader); - return allocation_header_address; +// This class tracks information about allocations made by `HtpSharedMemoryAllocator` instances. +// Given an address within a tracked allocation, we can look up information about it like the base address and the +// allocating allocator instance. +class AllocationTracker { + public: + struct Record { + void* base_address; + size_t size_in_bytes; + gsl::not_null allocator; + }; + + // Starts tracking an allocation. + // Returns true if successful, or false if there is already a tracked allocation at `base_address`. + bool RegisterAllocation(void* base_address, size_t size_in_bytes, HtpSharedMemoryAllocator& allocator); + + // Stops tracking an allocation. + // Returns true if successful, or false if there is no tracked allocation at `base_address`. + bool UnregisterAllocation(void* base_address); + + // Looks up a tracked allocation's record. + // Returns the record associated with the tracked allocation containing `address_within_allocation`, + // or `std::nullopt` if there is no such tracked allocation. + std::optional LookUp(void* address_within_allocation); + + private: + std::map records_; + std::shared_mutex records_mutex_; +}; + +bool AllocationTracker::RegisterAllocation(void* base_address, size_t size_in_bytes, + HtpSharedMemoryAllocator& allocator) { + Record record{base_address, size_in_bytes, &allocator}; + + std::unique_lock write_lock{records_mutex_}; + const bool registered = records_.emplace(base_address, std::move(record)).second; + return registered; } -AllocationHeader& ValidateAllocationAddressAndGetHeader(void* allocation_address) { - const size_t allocation_alignment = AllocationAlignment(); - ORT_ENFORCE(IsAligned(allocation_address, allocation_alignment), - "Allocation address (", allocation_address, ") does not have required alignment (", - allocation_alignment, " bytes)."); +bool AllocationTracker::UnregisterAllocation(void* base_address) { + std::unique_lock write_lock{records_mutex_}; + const bool unregistered = records_.erase(base_address) == 1; + return unregistered; +} - auto* allocation_header = reinterpret_cast(GetAllocationHeaderAddress(allocation_address)); - ORT_ENFORCE(allocation_header->marker == AllocationHeader::kAllocationHeaderMarker, - "AllocationHeader for allocation address (", allocation_address, - ") does not have the expected marker bytes."); +std::optional AllocationTracker::LookUp(void* address_within_allocation) { + std::shared_lock read_lock{records_mutex_}; - return *allocation_header; + // Look for a record where `address_within_allocation` falls within the range: + // [`record.base_address`, `record.base_address` + `record.size_in_bytes`) + + // First, find the first record with a base address greater than `address_within_allocation`, or the end of the + // container if no such record exists. + const auto first_record_with_larger_base_address_it = records_.upper_bound(address_within_allocation); + + // The previous record should have the greatest base address that is not greater than `address_within_allocation`. + // Make sure it exists. + if (first_record_with_larger_base_address_it == records_.begin()) { + return std::nullopt; + } + + const auto record_it = std::prev(first_record_with_larger_base_address_it); + + const auto record = record_it->second; + assert(address_within_allocation >= record.base_address); + + // Verify that `address_within_allocation` is within the upper end of the range. + if (reinterpret_cast(address_within_allocation) >= + reinterpret_cast(record.base_address) + record.size_in_bytes) { + return std::nullopt; + } + + return record; } -std::unique_ptr WrapSharedMemoryWithUniquePtr(void* shared_memory_raw, - const RpcMemApi& rpcmem_api) { - return {shared_memory_raw, rpcmem_api.free}; +AllocationTracker& GlobalAllocationTracker() { + static AllocationTracker allocation_tracker{}; + return allocation_tracker; } } // namespace @@ -110,8 +127,7 @@ HtpSharedMemoryAllocator::HtpSharedMemoryAllocator(std::shared_ptrApi().to_fd(shared_memory.get()); ORT_ENFORCE(shared_memory_fd != -1, "rpcmem_to_fd() returned invalid file descriptor."); - std::byte* allocation_address = reinterpret_cast(shared_memory_raw) + allocation_offset; + std::byte* allocation_address = reinterpret_cast(shared_memory_raw); // store allocation record { SharedMemoryInfo shared_memory_info{}; shared_memory_info.fd = shared_memory_fd; - shared_memory_info.offset = allocation_offset; + shared_memory_info.offset = 0; shared_memory_info.total_size = shared_memory_block_size_in_bytes; AllocationRecord allocation_record{}; @@ -150,10 +166,14 @@ void* HtpSharedMemoryAllocator::Alloc(size_t requested_size) { ORT_ENFORCE(inserted, "Allocation record already exists for address (", allocation_address, ")."); } - // initialize header + // register with global allocation tracker { - std::byte* allocation_header_address = GetAllocationHeaderAddress(allocation_address); - new (allocation_header_address) AllocationHeader(this); + const bool registered = GlobalAllocationTracker().RegisterAllocation(allocation_address, + shared_memory_block_size_in_bytes, + *this); + + ORT_ENFORCE(registered, "Attempted to register allocation but it is already tracked for address (", + allocation_address, ")."); } shared_memory.release(); @@ -165,11 +185,6 @@ void HtpSharedMemoryAllocator::Free(void* allocation_address) { return; } - auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); - ORT_ENFORCE(allocation_header.allocator_ptr == this, - "AllocationHeader points to a different allocator (", allocation_header.allocator_ptr, - ") than this one (", this, ")."); - const auto allocation_node = [this, allocation_address]() { std::scoped_lock g{allocations_mutex_}; return allocations_.extract(allocation_address); @@ -181,12 +196,16 @@ void HtpSharedMemoryAllocator::Free(void* allocation_address) { // Avoid throwing exceptions as this may be running from a destructor. try { // take ownership of shared memory and free at end of scope - const size_t allocation_offset = AllocationOffsetFromStartOfHeader(); - void* raw_allocation_address = (void*)((std::byte*)allocation_address - allocation_offset); - auto shared_memory = WrapSharedMemoryWithUniquePtr(raw_allocation_address, rpcmem_lib_->Api()); - - // destroy header - allocation_header.~AllocationHeader(); + auto shared_memory = WrapSharedMemoryWithUniquePtr(allocation_address, rpcmem_lib_->Api()); + + // unregister with global allocation tracker + { + const bool unregistered = GlobalAllocationTracker().UnregisterAllocation(allocation_address); + if (!unregistered) { + LOGS(logger_, ERROR) << "Attempted to deregister allocation but it is untracked for address (" + << allocation_address << ")."; + } + } // clean up allocation record const auto& allocation_record = allocation_node.mapped(); @@ -204,39 +223,55 @@ void HtpSharedMemoryAllocator::Free(void* allocation_address) { } } -Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfo(void* allocation_address, - SharedMemoryInfo& allocation_info) { - auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); - return allocation_header.allocator_ptr->GetAllocationSharedMemoryInfoForThisAllocator(allocation_address, - allocation_info); +Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfo(void* address_within_allocation, + SharedMemoryInfo& shared_memory_info_out) { + const auto tracked_record = GlobalAllocationTracker().LookUp(address_within_allocation); + ORT_RETURN_IF_NOT(tracked_record.has_value(), "Failed to look up tracked allocation."); + + void* const base_address = tracked_record->base_address; + SharedMemoryInfo shared_memory_info{}; + ORT_RETURN_IF_ERROR(tracked_record->allocator->GetAllocationSharedMemoryInfoForThisAllocator( + base_address, shared_memory_info)); + + // adjust `shared_memory_info.offset` for `address_within_allocation` + const auto offset_from_base = std::distance(reinterpret_cast(base_address), + reinterpret_cast(address_within_allocation)); + + shared_memory_info.offset += offset_from_base; + + shared_memory_info_out = std::move(shared_memory_info); + return Status::OK(); } -Status HtpSharedMemoryAllocator::AddAllocationCleanUp(void* allocation_address, +Status HtpSharedMemoryAllocator::AddAllocationCleanUp(void* address_within_allocation, AllocationCleanUpFn&& allocation_clean_up) { - auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); - return allocation_header.allocator_ptr->AddAllocationCleanUpForThisAllocator(allocation_address, - std::move(allocation_clean_up)); + const auto tracked_record = GlobalAllocationTracker().LookUp(address_within_allocation); + ORT_RETURN_IF_NOT(tracked_record.has_value(), "Failed to look up tracked allocation."); + + void* const base_address = tracked_record->base_address; + return tracked_record->allocator->AddAllocationCleanUpForThisAllocator(base_address, + std::move(allocation_clean_up)); } -Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_address, +Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_base_address, SharedMemoryInfo& allocation_info) { std::scoped_lock g{allocations_mutex_}; - const auto allocation_it = allocations_.find(allocation_address); + const auto allocation_it = allocations_.find(allocation_base_address); ORT_RETURN_IF(allocation_it == allocations_.end(), - "Failed to get allocation info for address (", allocation_address, ")."); + "Failed to get allocation info for address (", allocation_base_address, ")."); allocation_info = allocation_it->second.shared_memory_info; return Status::OK(); } -Status HtpSharedMemoryAllocator::AddAllocationCleanUpForThisAllocator(void* allocation_address, +Status HtpSharedMemoryAllocator::AddAllocationCleanUpForThisAllocator(void* allocation_base_address, AllocationCleanUpFn&& allocation_clean_up) { ORT_RETURN_IF(allocation_clean_up == nullptr, "allocation_clean_up should not be empty."); std::scoped_lock g{allocations_mutex_}; - const auto allocation_it = allocations_.find(allocation_address); + const auto allocation_it = allocations_.find(allocation_base_address); ORT_RETURN_IF(allocation_it == allocations_.end(), - "Failed to get allocation info for address (", allocation_address, ")."); + "Failed to get allocation info for address (", allocation_base_address, ")."); auto& clean_up_fns = allocation_it->second.clean_up_fns; clean_up_fns.emplace_back(std::move(allocation_clean_up)); diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.h b/onnxruntime/core/providers/qnn/qnn_allocator.h index e64f38f494b35..f91383cb788df 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.h +++ b/onnxruntime/core/providers/qnn/qnn_allocator.h @@ -34,24 +34,27 @@ class HtpSharedMemoryAllocator : public IAllocator { }; // Gets an allocation's shared memory info. - // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been - // freed. - static Status GetAllocationSharedMemoryInfo(void* allocation_address, + // `address_within_allocation` identifies the allocation. It must be an address within an allocation returned by + // Alloc() which has not yet been freed. + static Status GetAllocationSharedMemoryInfo(void* address_within_allocation, SharedMemoryInfo& allocation_info); - using AllocationCleanUpFn = std::function; + // Allocation clean up callback signature. + // For a given allocation, any added clean up callbacks will be called with the allocation's base address when the + // allocation is freed. + using AllocationCleanUpFn = std::function; // Adds allocation clean up callback to call when the allocation is freed. - // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been - // freed. + // `address_within_allocation` identifies the allocation. It must be an address within an allocation returned by + // Alloc() which has not yet been freed. // `allocation_clean_up` is the clean up callback. The associated allocator takes ownership of the callback. - static Status AddAllocationCleanUp(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up); + static Status AddAllocationCleanUp(void* address_within_allocation, AllocationCleanUpFn&& allocation_clean_up); private: - Status GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_address, + Status GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_base_address, SharedMemoryInfo& allocation_info); - Status AddAllocationCleanUpForThisAllocator(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up); + Status AddAllocationCleanUpForThisAllocator(void* allocation_base_address, AllocationCleanUpFn&& allocation_clean_up); struct AllocationRecord { SharedMemoryInfo shared_memory_info; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index ed5fd60fc71d8..2d117927cbaf7 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -431,6 +431,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio // Initialize rpcmem_library_. // This is necessary for HtpSharedMemoryAllocator to function and also indicates that the allocator is available. rpcmem_library_ = std::make_shared(); + model_settings_.htp_shared_memory = true; } dump_json_qnn_graph_ = ParseBoolOption("dump_json_qnn_graph", false, provider_options_map); @@ -1320,4 +1321,14 @@ std::vector QNNExecutionProvider::CreatePreferredAllocators() { return allocators; } +OrtDevice QNNExecutionProvider::GetOrtDeviceByMemType(OrtMemType /* em_type */) const { + // We are disabling the HTP shared memory allocator for intermediate values + // until we learn how to deal with memhandle costs. + // if (IsHtpSharedMemoryAllocatorAvailable()) { + // return qnn::HtpSharedMemoryAllocator::AssociatedMemoryInfo().device; + //} + // Default CPU allocator + return default_device_; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index d7a5d04d22692..7769a4a453c1b 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -51,6 +51,8 @@ class QNNExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; + OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; + private: std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 26adc0aaa8686..58d4461c7c32a 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -435,6 +435,7 @@ void InitProviderOrtApi(); inline Env& GetDefaultEnv() { return g_host->Env__Default(); } + } // namespace onnxruntime #define CREATE_MESSAGE(logger, severity, category, datatype) \ diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index ff29de6aa71db..eee6a05f12729 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -525,7 +525,7 @@ Status NonMaxSuppressionBase::GetThresholdsFromInputs(const PrepareContext& pc, Status GatherBase::PrepareForCompute(OpKernelContext* context, GatherBase::Prepare& p) const { return g_host_cpu.GatherBase__PrepareForCompute(this, context, reinterpret_cast(p)); } Status UnsqueezeBase::PrepareCompute(OpKernelContext* ctx, UnsqueezeBase::Prepare& p) const { return g_host_cpu.UnsqueezeBase__PrepareCompute(this, ctx, reinterpret_cast(p)); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) bool TileOp::IsTileMemcpy(const TensorShape& input_shape, const int64_t* repeats, size_t rank, bool& is_batched_memcpy, size_t& num_of_elements_per_batch, size_t& num_of_copies_per_batch, size_t& num_of_batch_copies) { return g_host_cpu.TileOp__IsTileMemcpy(input_shape, repeats, rank, is_batched_memcpy, num_of_elements_per_batch, num_of_copies_per_batch, num_of_batch_copies); } diff --git a/onnxruntime/core/providers/tensorrt/onnxruntime_providers_tensorrt.rc b/onnxruntime/core/providers/tensorrt/onnxruntime_providers_tensorrt.rc new file mode 100644 index 0000000000000..891f6a279510b --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/onnxruntime_providers_tensorrt.rc @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file REQUIRES the following external definitions: +// FILE_NAME, VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE, and VER_STRING + +#include + +#if defined(DEBUG) || defined(_DEBUG) +#define VER_DEBUG VS_FF_DEBUG +#else +#define VER_DEBUG 0 +#endif + +// ----------------------------------------------------------------------------- + +VS_VERSION_INFO VERSIONINFO +FILEVERSION VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE +PRODUCTVERSION VER_MAJOR, VER_MINOR, VER_BUILD, VER_PRIVATE +FILEFLAGSMASK VS_FFI_FILEFLAGSMASK +FILEFLAGS VER_DEBUG +FILEOS VOS__WINDOWS32 +FILETYPE VFT_DLL +FILESUBTYPE VFT2_UNKNOWN + +BEGIN + BLOCK "StringFileInfo" + BEGIN + BLOCK "040904E4" + BEGIN + VALUE "CompanyName", "Microsoft Corporation" + VALUE "FileDescription", "ONNX Runtime TensorRT Provider" + VALUE "FileVersion", VER_STRING + VALUE "InternalName", "ONNX Runtime TensorRT Provider" + VALUE "LegalCopyright", "\251 Microsoft Corporation. All rights reserved." + VALUE "OriginalFilename", FILE_NAME + VALUE "ProductName", "Microsoft\256 Windows\256 Operating System" + VALUE "ProductVersion", VER_STRING + END + END + + BLOCK "VarFileInfo" + BEGIN + VALUE "Translation", 0x409, 1252 + END +END diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index b51e882629177..60f115ca50da4 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -268,10 +268,10 @@ void initialize_vitisai_ep() { } void deinitialize_vitisai_ep() { + vaip::deregister_xir_ops(s_domains_vitisaiep); if (s_library_vitisaiep.deinitialize_onnxruntime_vitisai_ep) { s_library_vitisaiep.deinitialize_onnxruntime_vitisai_ep(); } - vaip::deregister_xir_ops(s_domains_vitisaiep); // kernel registry would be repopulated, no need to delete kernel registry s_domains_vitisaiep.clear(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index d40da70726b43..63949116507e4 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (16u) +#define VAIP_ORT_API_MAJOR (17u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index ab8a95b38491d..1d812779da265 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -10,8 +10,9 @@ #include // 1st-party headers/libs. -#include "core/platform/env_var_utils.h" #include "core/common/exceptions.h" +#include "core/platform/env_var_utils.h" +#include "core/providers/qnn/ort_api.h" #include "vaip/capability.h" #include "vaip/global_api.h" @@ -25,7 +26,11 @@ constexpr const char* VITISAI = "VITISAI"; VitisAIExecutionProvider::VitisAIExecutionProvider( const ProviderOptions& info) - : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { + : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, + DEFAULT_CPU_ALLOCATOR_DEVICE_ID, + kAlloc4KAlignment)}, + info_(info) { auto it = info_.find("ep_context_enable"); ep_ctx_enabled_ = it != info_.end() && it->second == "1"; it = info_.find("ep_context_embed_mode"); @@ -140,4 +145,24 @@ common::Status VitisAIExecutionProvider::SetEpDynamicOptions(gsl::span VitisAIExecutionProvider::GetProfiler() { return std::make_unique(); } + +std::vector VitisAIExecutionProvider::CreatePreferredAllocators() { + std::vector result; + // We do not want arena for this, as it would not respect alignment. + constexpr const bool use_arena_false = false; + AllocatorCreationInfo device_info_cpu_aligned_4k{ + [](OrtDevice::DeviceId device_id) { + return std::make_unique( + OrtMemoryInfo( + onnxruntime::CPU_ALIGNED_4K, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, + device_id, + kAlloc4KAlignment))); + }, + DEFAULT_CPU_ALLOCATOR_DEVICE_ID, use_arena_false}; + + result.push_back(CreateAllocator(device_info_cpu_aligned_4k)); + return result; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index f72f8cc721fbd..8db4f36dd497a 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -46,6 +46,8 @@ class VitisAIExecutionProvider : public IExecutionProvider { virtual common::Status SetEpDynamicOptions(gsl::span /*keys*/, gsl::span /*values*/) override; + std::vector CreatePreferredAllocators() override; + private: using my_ep_t = vaip_core::DllSafe>>; using my_ep_uptr_t = std::shared_ptr; diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h index 692dbc833f0a7..945abde5b40ad 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/base_op_builder.h @@ -49,7 +49,7 @@ class BaseOpBuilder : public IOpBuilder { virtual bool IsQuantizedOp(const NodeUnit& /* node_unit */) const { return false; } virtual int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const { return 1; } - virtual int GetMaxSupportedOpSet(const NodeUnit& /* node_unit */) const { return 22; } + virtual int GetMaxSupportedOpSet(const NodeUnit& /* node_unit */) const { return 23; } virtual bool HasSupportedInputOutputsImpl( const InitializedTensorSet& initializers, const NodeUnit& node_unit) const; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 3b5daef04dd50..76cf9c9a797e1 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -24,6 +24,7 @@ #include #include #include + #include "core/framework/compute_capability.h" #include "core/providers/vsinpu/vsinpu_execution_provider.h" #include "core/providers/vsinpu/vsinpu_ep_graph.h" @@ -36,25 +37,15 @@ #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/providers/partitioning_utils.h" +#include "core/providers/qnn/ort_api.h" + namespace onnxruntime { VSINPUExecutionProvider::VSINPUExecutionProvider(const VSINPUExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kVSINPUExecutionProvider}, + : IExecutionProvider{onnxruntime::kVSINPUExecutionProvider, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, + DEFAULT_CPU_ALLOCATOR_DEVICE_ID, + kAlloc4KAlignment)}, device_id_(info.device_id) { - AllocatorCreationInfo default_memory_info{ - [](int) { - return std::make_unique( - OrtMemoryInfo("VSINPU", OrtAllocatorType::OrtDeviceAllocator)); - }}; - - CreateAllocator(default_memory_info); - - AllocatorCreationInfo cpu_memory_info{ - [](int) { - return std::make_unique( - OrtMemoryInfo("VSINPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeCPUOutput)); - }}; - - CreateAllocator(cpu_memory_info); } VSINPUExecutionProvider::~VSINPUExecutionProvider() {} @@ -281,4 +272,23 @@ std::shared_ptr VSINPUExecutionProvider::GetKernelRegistry() con return kernel_registry; } +std::vector VSINPUExecutionProvider::CreatePreferredAllocators() { + std::vector result; + // We do not want arena for this, as it would not respect alignment. + constexpr const bool use_arena_false = false; + AllocatorCreationInfo device_info_cpu_aligned_4k{ + [](OrtDevice::DeviceId device_id) { + return std::make_unique( + OrtMemoryInfo( + onnxruntime::CPU_ALIGNED_4K, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, + device_id, + kAlloc4KAlignment))); + }, + device_id_, use_arena_false}; + + result.push_back(CreateAllocator(device_info_cpu_aligned_4k)); + return result; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h index 1c0b8b63a8e6c..1f96bde81b1d6 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h @@ -43,6 +43,8 @@ class VSINPUExecutionProvider : public IExecutionProvider { const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; std::shared_ptr GetKernelRegistry() const override; + std::vector CreatePreferredAllocators() override; + Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; std::mutex& GetMutex() { return mutex_; } diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 3117208c7be7d..f2569fce6b5eb 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -88,6 +88,9 @@ class ComputeContext { // // Create CPU tensor. // + // This method creates a tensor of the given data type and shape, using the CPU allocator. + // The tensor owns the underlying CPU memory buffer. + // template Tensor CreateCPUTensor(MLDataType data_type, TensorShapeType&& shape) { AllocatorPtr allocator; @@ -98,6 +101,9 @@ class ComputeContext { // // Create GPU tensor. // + // This method creates a tensor of the given data type and shape, using the WebGPU allocator. + // The tensor owns the underlying WebGPU storage buffer. + // template Tensor CreateGPUTensor(MLDataType data_type, TensorShapeType&& shape) { AllocatorPtr allocator; diff --git a/onnxruntime/core/providers/webgpu/math/einsum.cc b/onnxruntime/core/providers/webgpu/math/einsum.cc new file mode 100644 index 0000000000000..3595496f8450d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/einsum.cc @@ -0,0 +1,461 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/math/einsum.h" + +#include +#include +#include +#include + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +// Regular expressions for equation parsing. +static const std::regex symbol_pattern("[a-zA-Z]|\\.\\.\\."); +static const std::regex term_pattern("([a-zA-Z]|\\.\\.\\.)+"); +// Term can be empty in some cases like ,...i->...i, so allow empty term here. +static const std::regex lhs_pattern("(([a-zA-Z]|\\.\\.\\.)*,)*([a-zA-Z]|\\.\\.\\.)*"); + +// Helper function to remove all whitespaces in a given string. +std::string RemoveAllWhitespace(const std::string& str) { + std::string result = str; + result.erase(std::remove_if(result.begin(), result.end(), ::isspace), result.end()); + return result; +} + +bool IsInteger(const std::string& s) { + static const std::regex pattern(R"(^\d+$)"); + return std::regex_match(s, pattern); +} +} // namespace + +#define WEBGPU_EINSUM_TYPED_KERNEL_DECL(version) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Einsum, kOnnxDomain, version, float, kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Einsum); + +WEBGPU_EINSUM_TYPED_KERNEL_DECL(12); + +EinsumEquation::EinsumEquation(const std::vector& inputs, + const std::string& raw_equation) { + std::string lhs, rhs, equation = RemoveAllWhitespace(raw_equation); + size_t arrow_pos = equation.find("->"); + if (arrow_pos != std::string::npos) { + lhs = equation.substr(0, arrow_pos); + rhs = equation.substr(arrow_pos + 2); + } else { + lhs = equation; + rhs = ""; + } + + if (!std::regex_match(lhs, lhs_pattern)) { + ORT_THROW("Invalid LHS term"); + } + + // Parse LHS terms. + size_t pos = 0; + size_t find; + int input_idx = 0; + while ((find = lhs.find(',', pos)) != std::string::npos) { + auto term = lhs.substr(pos, find - pos); + if (!term.empty() && !std::regex_match(term, term_pattern)) { + ORT_THROW("Invalid LHS term"); + } + auto dims = inputs[input_idx]->Shape().GetDims(); + lhs_.push_back(ProcessTerm(term, true, dims, input_idx)); + pos = find + 1; + input_idx++; + } + auto last_term = lhs.substr(pos); + if (!last_term.empty() && !std::regex_match(last_term, term_pattern)) { + ORT_THROW("Invalid LHS term"); + } + auto dims = inputs[input_idx]->Shape().GetDims(); + lhs_.push_back(ProcessTerm(last_term, true, dims, input_idx)); + + if (!rhs.empty() && !std::regex_match(rhs, term_pattern)) { + ORT_THROW("Invalid RHS term"); + } + + // Handle empty RHS differently for implicit vs explicit modes. + // Implicit mode - arrow is not in the equation where the equation "ij,jk" equals to "ij,jk->ik" + // which is actually a matrix multiplication. + // Explicit mode - arrow is in the equation where the equation "ij,jk->" contains two steps, first + // step is a matrix multiplication just like the implicit mode, and the second step is to sum up + // the matrix produced by the first step to a scalar. + bool is_implicit_mode = arrow_pos == std::string::npos; + if (rhs.empty() && is_implicit_mode) { + // Implicit mode without RHS specified - construct output with repeated symbols + bool ellipsis_dim_calculated = false; + for (const auto& pair : symbol_to_info_) { + // Skip when symbol appears multiple times (except ellipsis dimensions) + // or when ellipsis dimensions have already been processed. + bool is_ellipsis_dim_symbol = IsInteger(pair.first); + bool should_skip = ((!is_ellipsis_dim_symbol && pair.second.count != 1) || + (is_ellipsis_dim_symbol && ellipsis_dim_calculated)); + + if (should_skip) { + continue; + } + + if (IsInteger(pair.first)) { + rhs += "..."; + ellipsis_dim_calculated = true; + } else { + rhs += pair.first; + } + } + } + + // Compute output dims. + std::sregex_iterator it(rhs.begin(), rhs.end(), symbol_pattern); + std::sregex_iterator end; + for (; it != end; ++it) { + std::string symbol = it->str(); + if (symbol == "...") { + output_dims.insert(output_dims.end(), ellipsis_dims_.begin(), ellipsis_dims_.end()); + } else { + auto info_it = symbol_to_info_.find(symbol); + if (info_it == symbol_to_info_.end()) { + ORT_THROW("Invalid RHS symbol"); + } + output_dims.push_back(info_it->second.dim_value); + } + } + + rhs_ = ProcessTerm(rhs, false, output_dims); +} + +void EinsumEquation::AddSymbol(const std::string& symbol, int64_t dim_value, int input_index) { + auto it = symbol_to_info_.find(symbol); + if (it != symbol_to_info_.end()) { + if (it->second.dim_value != dim_value && it->second.count != 1) { + ORT_THROW("Dimension mismatch"); + } + it->second.count++; + it->second.input_indices.push_back(input_index); + } else { + SymbolInfo info; + info.count = 1; + info.dim_value = dim_value; + info.input_indices.push_back(input_index); + symbol_to_info_[symbol] = info; + } +} + +EinsumTerm EinsumEquation::ProcessTerm(const std::string& term, bool is_input, + gsl::span dims, int index) { + EinsumTerm einsum_term; + einsum_term.input_index = index; + + // If the term is empty, return the einsum_term with empty symbol_to_indices. + // This is important for the case where the equation contains scalar like ",i...,->i...", in which + // case the term is empty. We need the term to generate the correct shader code. + if (term.empty()) { + return einsum_term; + } + + const size_t rank = dims.size(); + bool ellipsis = false; + std::vector ellipsis_dims; + size_t next_dim = 0; + + std::sregex_iterator it(term.begin(), term.end(), symbol_pattern); + std::sregex_iterator end; + for (size_t i = 0; it != end; ++it, ++i) { + std::string symbol = it->str(); + if (symbol == "...") { + if (ellipsis) { + ORT_THROW("Only one ellipsis is allowed per input term"); + } + ellipsis = true; + std::sregex_iterator symbol_it(term.begin(), term.end(), symbol_pattern); + std::sregex_iterator symbol_end; + size_t symbol_distance = std::distance(symbol_it, symbol_end) - 1; + if (rank < symbol_distance) { + ORT_THROW("Ellipsis out of bounds"); + } + ellipsis_dims.assign(dims.begin() + next_dim, + dims.begin() + (next_dim + rank - symbol_distance)); + if (has_ellipsis_) { + if (ellipsis_dims_ != ellipsis_dims) { + ORT_THROW("Ellipsis dimensions mismatch"); + } + } else if (is_input) { + has_ellipsis_ = true; + ellipsis_dims_ = ellipsis_dims; + } else { + ORT_THROW("Ellipsis must be specified in the LHS"); + } + // Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions. + for (size_t j = 0; j < ellipsis_dims.size(); ++j) { + std::string symbol_j = std::to_string(j); + einsum_term.symbol_to_indices[symbol_j].push_back(i + j); + AddSymbol(symbol_j, dims[next_dim++], index); + } + } else { + einsum_term.symbol_to_indices[symbol].push_back( + i + (has_ellipsis_ ? ellipsis_dims_.size() - 1 : 0)); + AddSymbol(symbol, dims[next_dim++], index); + } + } + return einsum_term; +} + +Status EinsumProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Add inputs and output. + const ShaderVariableHelper& input0 = + shader.AddInput("input0", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + std::vector> inputs; + inputs.push_back(input0); + + for (size_t i = 1; i < input_count_; ++i) { + inputs.push_back(shader.AddInput("input" + std::to_string(i), + ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias)); + } + + const ShaderVariableHelper& output = + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + // Helper variables for shader generation. + std::string init_prod = "var prod = 1.0;"; + std::string init_sum = "var sum = 0.0;"; + std::string update_sum = "sum += prod;"; + std::vector idx_copy; + std::vector reduce_ops; + std::vector reduce_ops_set_indices; + std::vector reduce_ops_loop_headers; + std::vector reduce_ops_loop_footers; + std::vector reduce_op_compute; + bool is_reduce_ops_without_loop = + parsed_equation_.symbol_to_info_.size() == parsed_equation_.rhs_.symbol_to_indices.size(); + std::set uniform_symbol_set; + for (const auto& pair : parsed_equation_.symbol_to_info_) { + const std::string& symbol = pair.first; + const SymbolInfo& info = pair.second; + if (parsed_equation_.rhs_.symbol_to_indices.find(symbol) != + parsed_equation_.rhs_.symbol_to_indices.end()) { + // Find the indices in the right-hand side (output) term for the current symbol + auto rhs_indices = parsed_equation_.rhs_.symbol_to_indices.find(symbol); + // Skip if symbol doesn't appear in output or has no indices + // This means this symbol is not needed for output calculation + if (rhs_indices == parsed_equation_.rhs_.symbol_to_indices.end() || + rhs_indices->second.empty()) { + continue; + } + + int lhs_term_index = 0; + for (const auto& term : parsed_equation_.lhs_) { + // Skip if the current input tensor index is not associated with this symbol + // This check ensures we only process input indices that actually have this symbol. + if (std::find(info.input_indices.begin(), info.input_indices.end(), lhs_term_index) == + info.input_indices.end()) { + lhs_term_index++; + continue; + } + + auto it = term.symbol_to_indices.find(symbol); + if (it == term.symbol_to_indices.end()) { + ORT_THROW("Invalid symbol error"); + } + + // For each input index associated with the current symbol in this term + for (auto input_index : it->second) { + // Copy output indices to input indices for dimensions that appear in both input and + // output Example: For equation "ij,jk->ik", when symbol='i', this copies the 'i' index + // from output to input0 Format like: input0Indices[0] = outputIndices[0], for the 'i' + // symbol + idx_copy.push_back(inputs[lhs_term_index].get().IndicesSet( + "input" + std::to_string(lhs_term_index) + "Indices", std::to_string(input_index), + output.IndicesGet("outputIndices", std::to_string(rhs_indices->second[0])))); + } + + lhs_term_index++; + } + } else { + int lhs_term_index = 0; + for (const auto& term : parsed_equation_.lhs_) { + // Always construct the string for multiplying the input value to the product accumulator + // Format like: prod *= get_input0_by_indices(input0Indices); + std::string get_indices_str = "prod *= " + + inputs[lhs_term_index].get().GetByIndices( + "input" + std::to_string(lhs_term_index) + "Indices") + + ";"; + + // Only add this computation to reduce_op_compute if it hasn't been added before + // This prevents duplicate multiplications for the same input term since the same symbol + // can appear in multiple terms. + if (std::find(reduce_op_compute.begin(), reduce_op_compute.end(), get_indices_str) == + reduce_op_compute.end()) { + reduce_op_compute.push_back(get_indices_str); + } + + // Skip if the current input tensor index is not associated with this symbol + // This check ensures we only process input indices that actually have this symbol. + if (std::find(info.input_indices.begin(), info.input_indices.end(), lhs_term_index) == + info.input_indices.end()) { + lhs_term_index++; + continue; + } + + auto it = term.symbol_to_indices.find(symbol); + if (it == term.symbol_to_indices.end()) { + ORT_THROW("Invalid symbol error"); + } + + for (auto input_index : it->second) { + // Set the input indices for the current input tensor at the given input_index position + // Format like: input0Indices[1] = j, given equation "ij,jk->ik". + reduce_ops_set_indices.push_back(inputs[lhs_term_index].get().IndicesSet( + "input" + std::to_string(lhs_term_index) + "Indices", std::to_string(input_index), + symbol)); + + // Check if we've already processed this symbol to avoid duplicate loop generation + if (uniform_symbol_set.find(symbol) == uniform_symbol_set.end()) { + // Add symbol to tracked set to prevent duplicate processing + uniform_symbol_set.insert(symbol); + + // Generate a WGSL loop header for reduction over this dimension + // Format like: for(var j: u32 = 0; j < uniforms.input0_shape[1]; j++) {, given equation + // "ij,jk->ik". + reduce_ops_loop_headers.push_back("for(var " + symbol + ": u32 = 0; " + symbol + " < " + + "uniforms.input" + std::to_string(lhs_term_index) + + "_shape[" + std::to_string(input_index) + "]; " + + symbol + "++) {"); + + // Add corresponding loop closing brace + reduce_ops_loop_footers.push_back("}"); + } + } + + lhs_term_index++; + } + } + } + + // Generate shader code based on reduction type. + if (is_reduce_ops_without_loop) { + // Direct multiplication without reduction loops. + std::string sum_statement = "let sum = " + inputs[0].get().GetByIndices("input0Indices"); + for (size_t i = 1; i < inputs.size(); ++i) { + sum_statement += + " * " + inputs[i].get().GetByIndices("input" + std::to_string(i) + "Indices"); + } + + sum_statement += ";"; + reduce_ops.push_back(sum_statement); + } else { + // Reduction operation with loops. + reduce_ops.push_back(init_sum); + for (const auto& header : reduce_ops_loop_headers) { + reduce_ops.push_back(header); + } + for (const auto& set_idx : reduce_ops_set_indices) { + reduce_ops.push_back(set_idx); + } + reduce_ops.push_back(init_prod); + for (const auto& compute : reduce_op_compute) { + reduce_ops.push_back(compute); + } + reduce_ops.push_back(update_sum); + for (const auto& footer : reduce_ops_loop_footers) { + reduce_ops.push_back(footer); + } + } + + // Add safety check to ensure workgroup sizes don't exceed output tensor dimensions + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"); + + // Special handling for scalar output + bool is_scalar_output = parsed_equation_.output_dims.empty(); + if (is_scalar_output) { + // For scalar output, only process the first workgroup thread. This is a special case where the + // output is a single scalar value. The global index is set to 0, and the rest of the threads + // are ignored. This is important for the case where the equation is finally reduced to a + // scalar. For example, the equation "ij->" is a matrix summation and the output is a scalar, + // the shader code will only execute for the first workgroup thread. There may be some space for + // optimization here. + shader.MainFunctionBody() << "if (global_idx != 0u) { return; }\n"; + } else { + // Convert global linear index to N-dimensional indices for the output tensor + // This maps a 1D global thread ID to the corresponding N-D output tensor coordinates + shader.MainFunctionBody() << "var outputIndices = " << output.OffsetToIndices("global_idx") + << ";\n"; + } + + // Define input indices with appropriate types. + for (size_t i = 0; i < input_count_; i++) { + shader.MainFunctionBody() << "var input" << i << "Indices: input" << std::to_string(i) + << "_indices_t;\n"; + } + + // Copy output indices to input indices. + for (const auto& idx : idx_copy) { + shader.MainFunctionBody() << idx << "\n"; + } + + // Add reduce operations. + for (const auto& op : reduce_ops) { + shader.MainFunctionBody() << op << "\n"; + } + + // Handle output value assignment based on the output type (scalar or tensor) + if (is_scalar_output) { + // For scalar output, write the sum to the first (and only) output element at offset 0 + shader.MainFunctionBody() << output.SetByOffset("0", "sum") << "\n"; + } else { + // For tensor output, write the sum to the output element at the current global thread index + // This maps each thread's result to the corresponding position in the output tensor + shader.MainFunctionBody() << output.SetByOffset("global_idx", "sum") << "\n"; + } + + return Status::OK(); +} + +Status Einsum::ComputeInternal(ComputeContext& context) const { + if (context.InputCount() < 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Einsum requires at least one input tensor."); + } + + std::vector input_tensors; + for (int i = 0; i < context.InputCount(); ++i) { + input_tensors.push_back(context.Input(i)); + } + + // TODO: The EinsumEquation initialization could potentially be done during model loading + // based on input/output shape inference results. This would improve runtime performance + // by avoiding redundant initialization on every compute call. + EinsumEquation equation(input_tensors, equation_); + const std::vector& output_dims = equation.output_dims; + Tensor* Y = context.Output(0, output_dims); + int64_t output_size = Y->Shape().Size(); + if (output_size == 0) { + return Status::OK(); + } + + // Create program with input count and the parsed equation. + EinsumProgram program{input_tensors.size(), equation}; + + for (size_t i = 0; i < input_tensors.size(); ++i) { + program.AddInput({input_tensors[i], ProgramTensorMetadataDependency::TypeAndRank}); + } + + // Add output and base uniforms. + program.CacheHint(equation_) + .SetDispatchGroupSize(static_cast((output_size + 63) / 64)) + .AddOutput({Y, ProgramTensorMetadataDependency::TypeAndRank}) + .AddUniformVariables({static_cast(output_size)}); + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/einsum.h b/onnxruntime/core/providers/webgpu/math/einsum.h new file mode 100644 index 0000000000000..df5d7e8ddd0bc --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/einsum.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { +struct SymbolInfo { + size_t count{0}; + std::vector input_indices; + int64_t dim_value{0}; +}; + +struct EinsumTerm { + // Indices of the symbols in the term which cannot be negative. + std::map> symbol_to_indices; + // The index of the input tensor in the Einsum equation. + // This is -1 for the output term. + int input_index{-1}; +}; + +class EinsumEquation { + public: + EinsumEquation(const std::vector& inputs, const std::string& equation); + std::vector output_dims; + std::map symbol_to_info_; + std::vector lhs_; + EinsumTerm rhs_; + + private: + bool has_ellipsis_{false}; + std::vector ellipsis_dims_; + void AddSymbol(const std::string& symbol, int64_t dim_value, int input_index); + EinsumTerm ProcessTerm(const std::string& term, + bool is_input, + gsl::span dims, + int index = -1); +}; + +class EinsumProgram final : public Program { + public: + EinsumProgram(size_t input_count, const EinsumEquation& parsed_equation) + : Program{"Einsum"}, input_count_(input_count), parsed_equation_{parsed_equation} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + size_t input_count_; + const EinsumEquation& parsed_equation_; +}; + +class Einsum final : public WebGpuKernel { + public: + Einsum(const OpKernelInfo& info) : WebGpuKernel(info) { + std::string equation; + ORT_ENFORCE(info.GetAttr("equation", &equation).IsOK()); + equation_ = equation; + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + std::string equation_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/gemm.cc b/onnxruntime/core/providers/webgpu/math/gemm.cc index 4057b63f0c65d..ac8d1a590c250 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/math/gemm.h" +#include "core/providers/webgpu/math/gemm_vec4.h" #include @@ -197,6 +198,11 @@ Status Gemm::ComputeInternal(ComputeContext& context) const { return Status::OK(); } + // First try vec4 optimization if possible + if (CanApplyGemmVec4(A, B)) { + return ApplyGemmVec4(A, B, C, transA_, transB_, alpha_, beta_, context, Y); + } + // WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty. bool need_handle_matmul = A_shape.Size() > 0 && B_shape.Size() > 0; bool need_handle_bias = C && beta_; @@ -220,14 +226,12 @@ Status Gemm::ComputeInternal(ComputeContext& context) const { .AddOutputs({{Y, ProgramTensorMetadataDependency::Type}}) .SetDispatchGroupSize(num_tile_n * num_tile_m) .SetWorkgroupSize(TILE_SIZE, TILE_SIZE) - .AddUniformVariables({ - {static_cast(num_tile_n)}, // num_tile_n - {static_cast(M)}, // M - {static_cast(N)}, // N - {static_cast(K)}, // K - {alpha_}, // alpha - {beta_} // beta - }); + .AddUniformVariables({{num_tile_n}, + {M}, + {N}, + {K}, + {alpha_}, + {beta_}}); return context.RunProgram(program); } diff --git a/onnxruntime/core/providers/webgpu/math/gemm_vec4.cc b/onnxruntime/core/providers/webgpu/math/gemm_vec4.cc new file mode 100644 index 0000000000000..6ba93df8247d2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/gemm_vec4.cc @@ -0,0 +1,314 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/math/gemm_vec4.h" + +#include "core/providers/webgpu/webgpu_utils.h" + +namespace onnxruntime { +namespace webgpu { + +void GemmVec4Program::MatMulReadFnSource(ShaderHelper& shader) const { + // We can’t treat `output_value_t` as the type of A and B, because output might not be a vec4, while A or B is. + const std::string data_type = "output_element_t"; + const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type); + + shader.AdditionalImplementation() + << "fn mm_readA(row: u32, col: u32, total_rows: u32, total_cols: u32) -> " << type_string << " { \n" + << " if(col < total_cols && row < total_rows) {\n" + << " return A[row * total_cols + col];\n" + << " } else {\n" + << " return " << type_string << "(0);\n" + << " }\n" + << "}\n\n"; + + shader.AdditionalImplementation() + << "fn mm_readB(row: u32, col: u32, total_rows: u32, total_cols: u32) -> " << type_string << "{ \n" + << " if(col < total_cols && row < total_rows) {\n" + << " return B[row * total_cols + col];\n" + << " } else {\n" + << " return " << type_string << "(0);\n" + << " }\n" + << "}\n\n"; +} + +void GemmVec4Program::MatMulWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& output) const { + shader.AdditionalImplementation() + << "fn mm_write(row: u32, col: u32, valuesIn: output_value_t) { \n"; + + if (output_components_ == 1) { + shader.AdditionalImplementation() << " let total_cols = uniforms.N; \n"; + } else { + shader.AdditionalImplementation() << " let total_cols = uniforms.N4; \n"; + } + + shader.AdditionalImplementation() << "var values = valuesIn; \n" + << "if(col < total_cols && row < uniforms.M) { \n"; + if (need_handle_bias_) { + const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform); + shader.AdditionalImplementation() << " values += output_element_t(uniforms.beta) * "; + // We can be allowed to use broadcasting only when both components are equal. + // There is only one case for c_components_ is not equal output_components_. + // I.g. the former is `1` and the latter is `4`. + // That means the shape of C is either {M,1} or {1,1} + if (c_components_ == output_components_) { + shader.AdditionalImplementation() << "output_value_t(" + << C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(row, col)", output)) << ");\n"; + } else if (c_is_scalar_) { + shader.AdditionalImplementation() << "output_value_t(C[0]);\n"; + } else { + shader.AdditionalImplementation() << "output_value_t(C[row]);\n"; + } + } + shader.AdditionalImplementation() << " output[row * total_cols + col] = values;\n" + << " }\n" + << "}\n"; +} + +Status GemmVec4Program::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + // We can’t treat `output_value_t` as the type of A and B, because output might not be a vec4, while A or B is. + const std::string data_type = "output_element_t"; + const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type); + + shader.MainFunctionBody() << " var values = " << type_string << "(0);\n\n" + << " let tile_col_start = (workgroup_idx % uniforms.num_tile_n) * 8u;\n" + << " let tile_row_start = (workgroup_idx / uniforms.num_tile_n) * 32u;\n"; + + if (need_handle_matmul_) { + shader.AddInput("A", ShaderUsage::UseUniform); + shader.AddInput("B", ShaderUsage::UseUniform); + + MatMulReadFnSource(shader); + + // Add shared memory arrays for tiling + shader.AdditionalImplementation() << "var tile_a: array, 32 >;\n " + << "var tile_b: array, 32 >;\n "; + + shader.MainFunctionBody() + << " var k_start_a = 0u;\n" + << " var k_start_b = 0u;\n\n" + << " let num_tiles = (uniforms.K + 32 - 1) / 32;\n"; + + // Main loop for matrix multiplication + shader.MainFunctionBody() + << " for (var t = 0u; t < num_tiles; t = t + 1u) {\n"; + // Load TILE_A + if (transA_) { + shader.MainFunctionBody() << R"TILE_A( + var row = k_start_a + (local_idx / 8u); + var col = tile_row_start/4 + local_idx % 8u; + tile_a[local_idx / 8u][local_idx % 8u] = mm_readA(row, col, uniforms.K, uniforms.M4); + )TILE_A"; + } else { + shader.MainFunctionBody() << R"TILE_A( + var row = tile_row_start + local_idx / 8u; + var col = k_start_a + (local_idx % 8u); + tile_a[local_idx / 8u][local_idx % 8u] = mm_readA(row, col, uniforms.M, uniforms.K4); + )TILE_A"; + } + // Load TILE_B + if (transB_) { + shader.MainFunctionBody() << R"TILE_B( + row = tile_col_start * 4 + (local_idx / 8u); + col = k_start_b + (local_idx % 8u); + // load 1 vec4 into tile_b + tile_b[local_idx / 8u][local_idx % 8u] = mm_readB(row, col, uniforms.N, uniforms.K4); + )TILE_B"; + } else { + shader.MainFunctionBody() << R"TILE_B( + row = k_start_b + (local_idx / 8u); + col = tile_col_start + (local_idx % 8u); + // load 1 vec4 into tile_b + tile_b[local_idx / 8u][local_idx % 8u] = mm_readB(row, col, uniforms.K, uniforms.N4); + )TILE_B"; + } + + shader.MainFunctionBody() << " workgroupBarrier();\n\n"; + + if (transA_) { + shader.MainFunctionBody() << "k_start_a = k_start_a + 32u; \n"; + } else { + shader.MainFunctionBody() << "k_start_a = k_start_a + 8u; \n"; + } + + if (transB_) { + shader.MainFunctionBody() << "k_start_b = k_start_b + 8u; \n"; + } else { + shader.MainFunctionBody() << "k_start_b = k_start_b + 32u; \n"; + } + + // Calculate output according to TILE_A and TILE_B + if (transA_ && transB_) { + shader.MainFunctionBody() << R"CALC( + // Calculate 4 output for each thread + // We read 32 vec4 from tile_a and 32 vec4 from tile_b in total. + for (var i = 0u; i < 32; i = i + 4u) { + let a1 = tile_a[i][local_idx / 32u]; + let a2 = tile_a[i + 1u][local_idx / 32u]; + let a3 = tile_a[i + 2u][local_idx / 32u]; + let a4 = tile_a[i + 3u][local_idx / 32u]; + let b1 = tile_b[(local_idx % 8) * 4][i / 4u]; + let b2 = tile_b[(local_idx % 8) * 4 + 1u][i / 4u]; + let b3 = tile_b[(local_idx % 8) * 4 + 2u][i / 4u]; + let b4 = tile_b[(local_idx % 8) * 4 + 3u][i / 4u]; + + var vec_idx = local_idx / 8u % 4; + + values[0] += a1[vec_idx] * b1[0] + a2[vec_idx] * b1[1] + a3[vec_idx] * b1[2] + a4[vec_idx] * b1[3]; + values[1] += a1[vec_idx] * b2[0] + a2[vec_idx] * b2[1] + a3[vec_idx] * b2[2] + a4[vec_idx] * b2[3]; + values[2] += a1[vec_idx] * b3[0] + a2[vec_idx] * b3[1] + a3[vec_idx] * b3[2] + a4[vec_idx] * b3[3]; + values[3] += a1[vec_idx] * b4[0] + a2[vec_idx] * b4[1] + a3[vec_idx] * b4[2] + a4[vec_idx] * b4[3]; + } + )CALC"; + } else if (transA_ && !transB_) { + shader.MainFunctionBody() << R"CALC( + // Calculate 4 output for each thread + // We read 32 vec4 from tile_a and 32 vec4 from tile_b in total. + for (var i = 0u; i < 32; i = i + 1u) { + let a = tile_a[i][local_idx / 32u]; + let b = tile_b[i][local_idx % 8u]; + values += a[(local_idx / 8u) % 4] * b; + })CALC"; + } else if (!transA_ && transB_) { + shader.MainFunctionBody() << R"CALC( + for (var i = 0u; i < 32; i = i + 4u) { + let a = tile_a[local_idx / 8u][i/4u]; + let b1 = tile_b[(local_idx % 8) * 4][i / 4u]; + let b2 = tile_b[(local_idx % 8) * 4 + 1u][i / 4u]; + let b3 = tile_b[(local_idx % 8) * 4 + 2u][i / 4u]; + let b4 = tile_b[(local_idx % 8) * 4 + 3u][i / 4u]; + + values += vec4( + dot(a, b1), + dot(a, b2), + dot(a, b3), + dot(a, b4) + ); + } + )CALC"; + } else { + shader.MainFunctionBody() << R"CALC( + for (var i = 0u; i < 32; i = i + 4u) { + let a = tile_a[local_idx / 8u][i/4u]; + let b1 = tile_b[i][local_idx % 8u]; + let b2 = tile_b[i+1][local_idx % 8u]; + let b3 = tile_b[i+2][local_idx % 8u]; + let b4 = tile_b[i+3][local_idx % 8u]; + + values += a.x * b1 + a.y * b2 + a.z * b3 + a.w * b4; + } + )CALC"; + } + shader.MainFunctionBody() << " workgroupBarrier();\n" + << " }\n\n"; + + // Calculate alpha + if (alpha_ != 1.0f) { + shader.MainFunctionBody() << " values = output_element_t(uniforms.alpha) * values;\n"; + } + } + + MatMulWriteFnSource(shader, output); + shader.MainFunctionBody() << " let m = tile_row_start + local_idx / 8u;\n" + << " let n = tile_col_start + local_idx % 8u;\n\n"; + + // Write output + if (output_components_ == 1) { + shader.MainFunctionBody() << " for (var i = 0u; i < 4u; i = i + 1u) {\n" + << " mm_write(m, 4 * n + i, values[i]);\n" + << " }\n"; + } else { + shader.MainFunctionBody() << " mm_write(m, n, values);\n"; + } + + return Status::OK(); +} + +bool CanApplyGemmVec4(const Tensor* a, + const Tensor* b) { + const auto& a_shape = a->Shape(); + const auto& b_shape = b->Shape(); + + // When the number of columns in A and B is divisible by 4, we apply vec4 optimization to A and B. + // However, this doesn't necessarily mean that C and Y will use vec4. + // For example, C/output won't be vec4 if B is transposed and N is not divisible by 4. + // Also, C won't use vec4 when it's a scalar. + // The code would be simpler if we avoided vec4 optimization for C/output. + // But to maximize performance, we still apply vec4 when possible — even though it adds some complexity. + // I've added detailed comments explaining this logic. + // See MatMulReadFnSource and MatMulWriteFnSource, especially the parts related to broadcasting. + return a_shape[1] % 4 == 0 && b_shape[1] % 4 == 0; +} + +Status ApplyGemmVec4(const Tensor* a, + const Tensor* b, + const Tensor* c, + bool transA, + bool transB, + float alpha, + float beta, + ComputeContext& context, + Tensor* y) { + const auto& a_shape = a->Shape(); + const auto& b_shape = b->Shape(); + + uint32_t M = onnxruntime::narrow(transA ? a_shape[1] : a_shape[0]); + uint32_t K = onnxruntime::narrow(transA ? a_shape[0] : a_shape[1]); + uint32_t N = onnxruntime::narrow(transB ? b_shape[0] : b_shape[1]); + + // WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty. + bool need_handle_matmul = a_shape.Size() > 0 && b_shape.Size() > 0; + bool need_handle_bias = c && beta; + + int c_components = 4; + bool c_is_scalar = false; + + // We use vec4 for C when its last dimension equals N and N is divisible by 4. + if (need_handle_bias) { + const auto& c_shape = c->Shape(); + int64_t c_last_dim = c_shape[c_shape.NumDimensions() - 1]; + c_components = (c_last_dim == N && N % 4 == 0) ? 4 : 1; + c_is_scalar = c_shape.Size() == 1; + } + + // We use vec4 for Y when N is divisible by 4. + const int output_components = N % 4 == 0 ? 4 : 1; + + GemmVec4Program program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, output_components}; + + const int components = 4; + + if (need_handle_matmul) { + program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, components}, + {b, ProgramTensorMetadataDependency::TypeAndRank, components}}); + } + + if (need_handle_bias) { + program.AddInput({c, ProgramTensorMetadataDependency::TypeAndRank, c_components}); + } + + const uint32_t TILE_SIZE = 32; + const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE; + + program.CacheHint(alpha, transA, transB, c_is_scalar) + .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) + .SetDispatchGroupSize(num_tile_n * num_tile_m) + .SetWorkgroupSize(256, 1, 1) + .AddUniformVariables({{num_tile_n}, + {M}, + {N}, + {K}, + {M / 4}, + {N / 4}, + {K / 4}, + {alpha}, + {beta}}); + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/gemm_vec4.h b/onnxruntime/core/providers/webgpu/math/gemm_vec4.h new file mode 100644 index 0000000000000..ae7be49ce9218 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/gemm_vec4.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +class GemmVec4Program final : public Program { + public: + GemmVec4Program(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul, int c_components, bool c_is_scalar, int output_components) + : Program{"GemmVec4"}, + transA_{transA}, + transB_{transB}, + alpha_{alpha}, + need_handle_bias_{need_handle_bias}, + need_handle_matmul_{need_handle_matmul}, + c_components_(c_components), + c_is_scalar_(c_is_scalar), + output_components_(output_components) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + void MatMulReadFnSource(ShaderHelper& shader) const; + void MatMulWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& output) const; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"num_tile_n", ProgramUniformVariableDataType::Uint32}, + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"M4", ProgramUniformVariableDataType::Uint32}, + {"N4", ProgramUniformVariableDataType::Uint32}, + {"K4", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}, + {"beta", ProgramUniformVariableDataType::Float32}); + + private: + bool transA_; + bool transB_; + float alpha_; + bool need_handle_bias_; + bool need_handle_matmul_; + int c_components_; + bool c_is_scalar_ = false; + int output_components_; +}; + +Status ApplyGemmVec4(const Tensor* a, + const Tensor* b, + const Tensor* c, + bool transA, + bool transB, + float alpha, + float beta, + ComputeContext& context, + Tensor* y); + +bool CanApplyGemmVec4(const Tensor* a, + const Tensor* b); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index cdd3909874e7f..4499f4a2432b9 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -232,7 +232,7 @@ MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector, rowPerThread>;\n"; + << " var acc: array, rowPerThread>;\n"; if (sequentially_access_by_threads) { shader.MainFunctionBody() << "let localRow = i32(local_id.y);\n" @@ -277,7 +278,7 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, << " BCached[inner] = mm_Bsub[k][localCol + inner * " << workgroup_size_x << "];\n" << " }\n" << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" - << " let ACached = " << (transpose_a ? "mm_Asub[k][localCol + innerRow * " + std::to_string(workgroup_size_y) + "];" : "mm_Asub[localRow + innerRow * " + std::to_string(workgroup_size_y) + "][k];") << "\n" + << " let ACached = " << (transpose_a ? "mm_Asub[k][localRow + innerRow * " + std::to_string(workgroup_size_y) + "];" : "mm_Asub[localRow + innerRow * " + std::to_string(workgroup_size_y) + "][k];") << "\n" << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" << " acc[innerRow][innerCol] = acc[innerRow][innerCol] +\n" << " ACached * BCached[innerCol];\n" diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 0edad3eebe2ea..467e2f3f3a5ce 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -198,7 +198,7 @@ Status Conv::ComputeInternal(ComputeContext& context .AddInputs({{matmul_inputs[0], ProgramTensorMetadataDependency::TypeAndRank, ReduceShapeByComponents(matmul_input_reshapes[0], a_components), int(a_components)}, {matmul_inputs[1], ProgramTensorMetadataDependency::TypeAndRank, ReduceShapeByComponents(matmul_input_reshapes[1], components), int(components)}}); if (has_bias) { - program.AddInput({bias, ProgramTensorMetadataDependency::Rank, bias->Shape(), components}); + program.AddInput({bias, ProgramTensorMetadataDependency::Rank, ReduceShapeByComponents(bias->Shape(), components), components}); } program .AddOutputs({{output, ProgramTensorMetadataDependency::None, ReduceShapeByComponents(matmul_output_shape, components), int(components)}}) diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc index 1e5e52215b53f..311808fdd9e09 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc @@ -159,7 +159,7 @@ Status Conv2dMMProgram::GenerateShaderCode(ShaderHelper& shader) const { << declaration_functions.str() << Conv2dCommonSnippet(x, w, activation_, "x_element_t", element_size_[0], element_size_[1], element_size_[2]); std::string data_type = "x_element_t"; - return is_vec4_ ? MatMulProgram::MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, /* transpose_a = */ !is_channels_last_, tile_inner_) : MatMulProgram::MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, false, tile_inner_, false, 0, sequentially_access_by_threads_); + return is_vec4_ ? MatMulProgram::MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, /* transpose_a = */ !is_channels_last_, tile_inner_) : MatMulProgram::MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, !is_channels_last_, tile_inner_, /* split_t = */ false, 0, sequentially_access_by_threads_); } Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::vector& inputs, const std::vector& pads, const std::vector& strides, const std::vector& dilations, Tensor* output, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner, bool is_channels_last, bool sequentially_access_by_threads, const std::vector& input_output_shapes) { diff --git a/onnxruntime/core/providers/webgpu/nn/instance_norm.cc b/onnxruntime/core/providers/webgpu/nn/instance_norm.cc index 0cab454a5a530..f3bccec4872fc 100644 --- a/onnxruntime/core/providers/webgpu/nn/instance_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/instance_norm.cc @@ -13,23 +13,25 @@ namespace onnxruntime { namespace webgpu { Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseIndicesTypeAlias); + const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const auto& scale = shader.AddInput("scale", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AdditionalImplementation() << "var workgroup_shared_sum : array;\n" - << "var workgroup_shared_squared_sum : array;\n" + shader.AdditionalImplementation() << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n" + << "var workgroup_shared_sum : array;\n" + << "var workgroup_shared_squared_sum : array;\n" << "const workgroup_size = " << workgroup_size_ << ";\n"; + shader.MainFunctionBody() << " let batch = workgroup_idx / uniforms.x_shape[1];\n" << " let channel = workgroup_idx % uniforms.x_shape[1];\n" << " let hight = uniforms.x_shape[2];\n" << " // initialize workgroup memory<< \n" - << " var sum = x_value_t(0);\n" - << " var squared_sum = x_value_t(0);\n" + << " var sum = f32_val_t(0);\n" + << " var squared_sum = f32_val_t(0);\n" << " for (var h = local_idx; h < hight; h += workgroup_size) {\n" << " let indices = x_indices_t(batch, channel, h);\n" - << " let value =" << input.GetByIndices("indices") << ";\n" + << " let value = f32_val_t(" << input.GetByIndices("indices") << ");\n" << " sum += value;\n" << " squared_sum += value * value;\n" << " }\n" @@ -44,12 +46,12 @@ Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader) << " workgroupBarrier();\n" << " }\n" << " if (local_idx == 0) {\n" - << " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / x_element_t(hight * " << components_ << ");\n" - << " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / x_element_t(hight * " << components_ << ");\n" - << " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + x_element_t(" << std::to_string(epsilon_) << "));\n" - << " let channel_scale = inv_std_dev * " << scale.GetByOffset("channel") << ";\n" - << " let channel_shift = " << bias.GetByOffset("channel") << " - sum_final * channel_scale;\n" - << " " << output.SetByOffset("workgroup_idx", "output_value_t(channel_scale, channel_shift)") << ";\n" + << " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n" + << " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n" + << " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32(" << std::to_string(epsilon_) << "));\n" + << " let channel_scale = inv_std_dev * f32(" << scale.GetByOffset("channel") << ");\n" + << " let channel_shift = f32(" << bias.GetByOffset("channel") << ") - sum_final * channel_scale;\n" + << " " << output.SetByOffset("workgroup_idx", "output_value_t(output_element_t(channel_scale), output_element_t(channel_shift))") << ";\n" << " }\n"; return Status::OK(); } @@ -110,7 +112,7 @@ Status InstanceNormProgramNHWC::GenerateShaderCode(ShaderHelper& shader) const { << "let input_value = " << input.GetByOffset("global_idx") << ";\n"; if (components_ > 1) { shader.MainFunctionBody() << "for (var i : u32 = 0; i < uniforms.components; i = i + 1) {\n" - << " let scale_sift = " << channel_scale_shift.GetByOffset("scale_offset + i") << ";\n" + << " let scale_sift = " << channel_scale_shift.GetByOffset("uniforms.components * scale_offset + i") << ";\n" << " scale[i] = input_element_t(scale_sift.x);\n" << " shift[i] = input_element_t(scale_sift.y);\n" << "}\n"; diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 89f547481b6e4..bd64416401fad 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -105,7 +105,10 @@ constexpr std::string_view ProgramVariableDataTypeName[] = { "i8x4", // Int8x4 "i8x8", // Int8x8 "i8x16", // Int8x16 + "u4x8", // Uint4x8 + "i4x8", // Int4x8 }; + std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) { os << ProgramVariableDataTypeName[std::underlying_type::type(type)]; return os; @@ -135,8 +138,12 @@ int NumberOfComponents(ProgramVariableDataType type) { case ProgramVariableDataType::Int8x4: return 4; case ProgramVariableDataType::Uint8x8: + case ProgramVariableDataType::Int8x8: + case ProgramVariableDataType::Uint4x8: + case ProgramVariableDataType::Int4x8: return 8; case ProgramVariableDataType::Uint8x16: + case ProgramVariableDataType::Int8x16: return 16; default: return -1; @@ -146,10 +153,6 @@ int NumberOfComponents(ProgramVariableDataType type) { ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) { if (component == 1) { switch (element_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - return ProgramVariableDataType::Uint8x4; // shader needs to be aware that only 1 value is valid - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - return ProgramVariableDataType::Int8x4; // shader needs to be aware that only 1 value is valid case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return ProgramVariableDataType::Float32; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: @@ -201,6 +204,10 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: return ProgramVariableDataType::Uint8x8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + return ProgramVariableDataType::Uint4x8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + return ProgramVariableDataType::Int4x8; default: return ProgramVariableDataType::InvalidType; } @@ -259,6 +266,15 @@ ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency } } +ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, ProgramInput::FlattenTag, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{} { + override_shape = {(tensor->Shape().Size() + component - 1) / component}; +} + ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) : tensor{tensor}, dependency{dependency}, @@ -273,6 +289,7 @@ ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dep : tensor{tensor}, dependency{dependency}, var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + is_atomic{false}, use_override_shape{component > 1}, override_shape{} { if (use_override_shape) { @@ -280,10 +297,19 @@ ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dep } } +ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, ProgramOutput::AtomicTag) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType())}, + is_atomic{true}, + use_override_shape{false}, + override_shape{} {} + ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) : tensor{tensor}, dependency{dependency}, var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + is_atomic{false}, use_override_shape{true}, override_shape{override_shape} {} diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 3b0acfa7d0d35..705e6cd2fdb37 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -201,6 +201,9 @@ enum class ProgramVariableDataType { Int8x4, Int8x8, Int8x16, + Uint4x8, + Int4x8, + // if you add a new type here, you also need to update ProgramVariableDataTypeName }; #ifndef NDEBUG std::ostream& operator<<(std::ostream& os, ProgramVariableDataType); @@ -211,8 +214,15 @@ int NumberOfComponents(ProgramVariableDataType type); ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); struct ProgramInput { + private: + struct FlattenTag {}; + + public: + constexpr static const FlattenTag Flatten{}; + ProgramInput(const Tensor* tensor); ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, FlattenTag, int component = 1); ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); const Tensor* tensor; @@ -223,13 +233,21 @@ struct ProgramInput { }; struct ProgramOutput { + private: + struct AtomicTag {}; + + public: + constexpr static const AtomicTag Atomic{}; + ProgramOutput(Tensor* tensor); ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1); + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, AtomicTag); ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component); Tensor* tensor; ProgramTensorMetadataDependency dependency; ProgramVariableDataType var_type; + bool is_atomic; bool use_override_shape; TensorShape override_shape; }; diff --git a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc index 866b1debf6dc8..e2b5d73168935 100644 --- a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc @@ -121,8 +121,11 @@ Status DequantizeLinear::ComputeInternal(ComputeContext& context) const { auto* output_tensor = context.Output(0, x_shape); int64_t x_scale_rank = x_scale->Shape().NumDimensions(); - bool packed = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; - bool is_signed = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + // Currently only INT8, UINT8, and INT32 are registered. + auto x_type = x->GetElementType(); + + bool packed = x_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + bool is_signed = x_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; int64_t axis = (axis_ >= 0) ? axis_ : axis_ + x_shape.NumDimensions(); int max_components = GetMaxComponents(x_size); @@ -138,14 +141,14 @@ Status DequantizeLinear::ComputeInternal(ComputeContext& context) const { bool use_components = per_layer && (!packed || max_components == 4); int components = use_components ? max_components : 1; - int input_component = use_components && !packed ? max_components : 1; + int input_component = use_components ? max_components : 1; DequantizeLinearProgram program{packed, is_signed, per_layer, per_axis, x_zeropoint != nullptr}; program - .AddInputs({{x, ProgramTensorMetadataDependency::TypeAndRank, input_component}}) + .AddInputs({{x, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, packed ? 4 : input_component}}) .AddInputs({{x_scale, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutput({output_tensor, ProgramTensorMetadataDependency::None, components}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank, components}) .SetDispatchGroupSize((x_size / components + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({{static_cast(axis)}}) .AddUniformVariables({{static_cast(block_size_)}}) @@ -153,7 +156,7 @@ Status DequantizeLinear::ComputeInternal(ComputeContext& context) const { .CacheHint(std::to_string(axis), std::to_string(is_signed), std::to_string(per_layer), std::to_string(per_axis), std::to_string(block_size_)); if (x_zeropoint != nullptr) { - program.AddInputs({{x_zeropoint, ProgramTensorMetadataDependency::TypeAndRank}}); + program.AddInputs({{x_zeropoint, ProgramTensorMetadataDependency::None, ProgramInput::Flatten, packed ? 4 : 1}}); } return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index bac360c4c270e..59855e6117641 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -123,7 +123,12 @@ const ShaderIndicesHelper& ShaderHelper::AddIndices(const std::string& name, Sha #ifndef NDEBUG // if debug build namespace { // Validate if the tensor element type matches the program variable data type -Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type) { +Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type, bool is_atomic = false) { + if (is_atomic) { + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || var_type == ProgramVariableDataType::Uint32, + "Unexpected program variable type ", int(var_type), " for atomic variable"); + } + switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float32 || @@ -174,6 +179,14 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va var_type == ProgramVariableDataType::Int8x16, "Unexpected program variable type ", int(var_type), " for int8 tensor"); break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int4x8, + "Unexpected program variable type ", int(var_type), " for int4 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint4x8, + "Unexpected program variable type ", int(var_type), " for uint4 tensor"); + break; default: ORT_RETURN_IF(true, "Unsupported data type: ", element_type); // todo: add int4/uint4 @@ -237,7 +250,7 @@ Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVar return Status::OK(); } Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariableHelper& var) const { - ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_, output.is_atomic)); ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), output.use_override_shape, output.use_override_shape ? output.override_shape : output.tensor->Shape(), @@ -400,12 +413,22 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // // Input/output variables // - size_t variable_count = 0; - for (const auto& input : input_vars_) { - ss << "@group(0) @binding(" << variable_count++ << ") var " << input->name_ << ": array<" << input->StorageType() << ">;\n"; + for (size_t i = 0; i < input_vars_.size(); ++i) { + const auto& input = input_vars_[i]; + ss << "@group(0) @binding(" << i << ") var " << input->name_ << ": array<" << input->StorageType() << ">;\n"; } - for (const auto& output : output_vars_) { - ss << "@group(0) @binding(" << variable_count++ << ") var " << output->name_ << ": array<" << output->StorageType() << ">;\n"; + for (size_t i = 0; i < output_vars_.size(); ++i) { + const auto& output = output_vars_[i]; + bool is_atomic = program_.Outputs()[i].is_atomic; + ss << "@group(0) @binding(" << input_vars_.size() + i << ") var " << output->name_ << ": array<"; + if (is_atomic) { + ss << "atomic<"; + } + ss << output->StorageType(); + if (is_atomic) { + ss << ">"; + } + ss << ">;\n"; } // @@ -512,7 +535,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha ss << "\n};\n" "@group(0) @binding(" - << variable_count << ") var uniforms: Uniforms;\n"; + << input_vars_.size() + output_vars_.size() << ") var uniforms: Uniforms;\n"; } // diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 502d03c2c2dd8..79175370529e0 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -33,6 +33,10 @@ constexpr static const std::string_view STORAGE_TYPE_ARRAY[] = { "vec2", // Uint8x8 "vec4", // Uint8x16 "u32", // Int8x4 + "vec2", // Int8x8 + "vec4", // Int8x16 + "u32", // Uint4x8 + "u32", // Int4x8 }; constexpr static const auto STORAGE_TYPE = details::_to_std_array(STORAGE_TYPE_ARRAY); @@ -55,7 +59,11 @@ constexpr static const std::string_view VALUE_TYPE_ARRAY[] = { "u32", // Uint8x4 (u32 as 4 elements of uint8) "vec2", // Uint8x8 (vec2 as 2x4 elements of uint8) "vec4", // Uint8x16 (vec4 as 4x4 elements of uint8) - "i32", // Int8x4 + "u32", // Int8x4 (u32 as 4 elements of int8) + "vec2", // Int8x8 (vec2 as 2x4 elements of int8) + "vec4", // Int8x16 (vec4 as 4x4 elements of int8) + "u32", // Uint4x8 + "u32", // Int4x8 }; constexpr static const auto VALUE_TYPE = details::_to_std_array(VALUE_TYPE_ARRAY); @@ -81,6 +89,8 @@ constexpr static const std::string_view ELEMENT_TYPE_ARRAY[] = { "i32", // Int8x4 "i32", // Int8x8 "i32", // Int8x16 + "u32", // Uint4x8 + "i32", // Int4x8 }; constexpr static const auto ELEMENT_TYPE = details::_to_std_array(ELEMENT_TYPE_ARRAY); diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 3e831f9853451..59e62043fd0c0 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -43,7 +43,9 @@ Status Expand::ComputeInternal(ComputeContext& context) const { const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4 : 1; uint32_t data_size = onnxruntime::narrow(output_shape.Size() / components_o); - + if (data_size == 0) { + return Status::OK(); + } ExpandProgram program{}; program .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}}) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index f5f108121cb8d..928e48d78d7e5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -699,7 +699,7 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index e02d9266e8a0e..86eb57f99f3b3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -5,11 +5,15 @@ #include #include "core/common/common.h" +#include "core/framework/tensor.h" #include "core/framework/tensor_shape.h" namespace onnxruntime { namespace webgpu { +/** + * Returns the maximum number of components `N` to be used as `vecN` for the given size. + */ inline int GetMaxComponents(int64_t size) { if (size % 4 == 0) { return 4; @@ -19,6 +23,11 @@ inline int GetMaxComponents(int64_t size) { return 1; } +/** + * Returns a string representing a WGSL expression that sums the components of a value T. + * + * T can be a scalar S, vec2 or vec4. + */ inline std::string SumVector(std::string x, int components) { switch (components) { case 1: @@ -49,5 +58,36 @@ inline std::string MakeScalarOrVectorType(int components, std::string_view data_ TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components); +/** + * Create a reshaped tensor from an existing tensor. + * + * The specified new shape must have the same number of elements as the original tensor. + * + * The new tensor is a "view" of the original tensor. It uses the same data of the original tensor. + * The new tensor does not take or share ownership of the underlying data. The original tensor must outlive the new tensor. + */ +inline Tensor CreateTensorView(const Tensor& tensor, const TensorShape& new_shape) { + ORT_ENFORCE(tensor.Shape().Size() == new_shape.Size(), "Cannot reshape tensor ", tensor.Shape().ToString(), " to ", new_shape.ToString()); + return {tensor.DataType(), new_shape, const_cast(tensor.DataRaw()), tensor.Location()}; +} + +/** + * Create a reinterpreted tensor from an existing tensor with a new data type and shape. + * + * The new data type and shape must match the original tensor's storage size. + * + * The new tensor is a "view" of the original tensor. It uses the same data of the original tensor. + * The new tensor does not take or share ownership of the underlying data. The original tensor must outlive the new tensor. + */ +inline Tensor CreateTensorView(const Tensor& tensor, MLDataType new_data_type, const TensorShape& new_shape) { + auto byte_size = Tensor::CalculateTensorStorageSize(tensor.DataType(), tensor.Shape()); + auto new_byte_size = Tensor::CalculateTensorStorageSize(new_data_type, new_shape); + ORT_ENFORCE(byte_size == new_byte_size, + "Cannot reshape tensor ", tensor.Shape().ToString(), " to ", new_shape.ToString(), + " with data type ", DataTypeImpl::ToString(new_data_type), ". The byte size of the original tensor is ", + byte_size, " and the byte size of the new tensor is ", new_byte_size); + return {new_data_type, new_shape, const_cast(tensor.DataRaw()), tensor.Location()}; +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index b794ff6a63a6c..ae52e2cd5d936 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -52,7 +52,7 @@ class BaseOpBuilder : public IOpBuilder { // We still set the mininal supported opset to 1 as we couldn't // get the model opset version at this stage. virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 21; } + virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 23; } private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index d28036aa656d8..d02dd61460f60 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -93,8 +93,9 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build uint32_t num_heads = helper.Get("num_heads", 0); uint32_t rotary_embedding_dim = helper.Get("rotary_embedding_dim", 0); - // The input is either with 3D tensor shape (batch_size, sequence_length, hidden_size) or - // 4D tensor shape (batch_size, num_heads, sequence_length, head_size) + // The input can be: + // - 3D: [batch_size, sequence_length, hidden_size] + // - 4D: [batch_size, num_heads, sequence_length, head_size] const uint32_t batch_size = static_cast(input_shape[0]); const uint32_t sequence_length = input_is_4d ? static_cast(input_shape[2]) : static_cast(input_shape[1]); @@ -109,9 +110,19 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build rotary_embedding_dim = head_size; } - // First ensure the input has shape (batch_size, num_heads, sequence_length, head_size). - if (!input_is_4d) { - const std::vector new_shape{batch_size, num_heads, sequence_length, head_size}; + emscripten::val transpose_options = emscripten::val::object(); + + // Ensure the input is reshaped to: [batch_size, sequence_length, num_heads, head_size]. + if (input_is_4d) { + // The input is already in 4D shape, but we need to ensure the order is + // [batch_size, sequence_length, num_heads, head_size] to make it broadcastable with + // the coming mul operator with cos_cache and sin_cache. + const std::vector permutation{0, 2, 1, 3}; + transpose_options.set("label", node_name + "_transpose_input"); + transpose_options.set("permutation", emscripten::val::array(permutation)); + input = wnn_builder.call("transpose", input, transpose_options); + } else { + const std::vector new_shape{batch_size, sequence_length, num_heads, head_size}; emscripten::val reshape_input_options = emscripten::val::object(); reshape_input_options.set("label", node_name + "_reshape_input"); input = wnn_builder.call( @@ -276,12 +287,21 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build output = wnn_builder.call("concat", concat_inputs, 3, concat_back_input_options); } - // Reshape the output to the original shape. The output shape is the same as the input shape. - const std::vector output_shape = GetNarrowedIntfromInt64(input_shape); - emscripten::val reshape_output_options = emscripten::val::object(); - reshape_output_options.set("label", node_name + "_reshape_output"); - output = wnn_builder.call( - "reshape", output, emscripten::val::array(output_shape), reshape_output_options); + if (input_is_4d) { + // The output is in 4D shape, we need to transpose it back to the original shape. + // Reuse the transpose_options' permutation because the original permutation also + // happens to be its own inverse. (inserve({0, 2, 1, 3} == {0, 2, 1, 3}) + transpose_options.set("label", node_name + "_transpose_output"); + output = wnn_builder.call("transpose", output, transpose_options); + } else { + // The output is in 3D shape, we need to reshape it back to the original shape. + // The output shape is same as the input shape. + const std::vector output_shape = GetNarrowedIntfromInt64(input_shape); + emscripten::val reshape_output_options = emscripten::val::object(); + reshape_output_options.set("label", node_name + "_reshape_output"); + output = wnn_builder.call( + "reshape", output, emscripten::val::array(output_shape), reshape_output_options); + } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/session/abi_key_value_pairs.h b/onnxruntime/core/session/abi_key_value_pairs.h index 28de183fde405..3242be817881a 100644 --- a/onnxruntime/core/session/abi_key_value_pairs.h +++ b/onnxruntime/core/session/abi_key_value_pairs.h @@ -19,10 +19,17 @@ struct OrtKeyValuePairs { Sync(); } void Add(const char* key, const char* value) { - return Add(std::string(key), std::string(value)); + // ignore if either are nullptr. + if (key && value) { + Add(std::string(key), std::string(value)); + } } void Add(const std::string& key, const std::string& value) { + if (key.empty()) { // ignore empty keys + return; + } + auto iter_inserted = entries.insert({key, value}); bool inserted = iter_inserted.second; if (inserted) { @@ -37,6 +44,10 @@ struct OrtKeyValuePairs { // we don't expect this to be common. reconsider using std::vector if it turns out to be. void Remove(const char* key) { + if (key == nullptr) { + return; + } + auto iter = entries.find(key); if (iter != entries.end()) { auto key_iter = std::find(keys.begin(), keys.end(), iter->first.c_str()); diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index ad7eb9cdfff25..d4b7ec1ff99fe 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -3,8 +3,11 @@ #include "core/session/environment.h" +#include + #include "core/common/basic_types.h" #include "core/framework/allocator_utils.h" +#include "core/framework/error_code_helper.h" #include "core/graph/constants.h" #include "core/graph/op.h" #include "core/platform/device_discovery.h" @@ -53,7 +56,7 @@ #include "orttraining/core/optimizer/graph_transformer_registry.h" #endif -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cuda/cuda_execution_provider_info.h" #endif @@ -62,9 +65,9 @@ using namespace ::onnxruntime::common; using namespace ONNX_NAMESPACE; std::once_flag schemaRegistrationOnceFlag; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) ProviderInfo_CUDA& GetProviderInfo_CUDA(); -#endif // USE_CUDA +#endif // defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) Status Environment::Create(std::unique_ptr logging_manager, std::unique_ptr& environment, @@ -350,7 +353,7 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ return CreateAndRegisterAllocator(mem_info, arena_cfg); } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) if (provider_type == onnxruntime::kCudaExecutionProvider) { CUDAExecutionProviderInfo cuda_ep_info; GetProviderInfo_CUDA().CUDAExecutionProviderInfo__FromProviderOptions(options, cuda_ep_info); @@ -468,6 +471,28 @@ Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_nam return status; } +namespace { +std::vector SortDevicesByType() { + auto& devices = DeviceDiscovery::GetDevices(); + std::vector sorted_devices; + sorted_devices.reserve(devices.size()); + + const auto select_by_type = [&](OrtHardwareDeviceType type) { + for (const auto& device : devices) { + if (device.type == type) { + sorted_devices.push_back(&device); + } + } + }; + + select_by_type(OrtHardwareDeviceType_NPU); + select_by_type(OrtHardwareDeviceType_GPU); + select_by_type(OrtHardwareDeviceType_CPU); + + return sorted_devices; +} +} // namespace + Status Environment::EpInfo::Create(std::unique_ptr library_in, std::unique_ptr& out, const std::vector& internal_factories) { if (!library_in) { @@ -482,36 +507,25 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u ORT_RETURN_IF_ERROR(instance.library->Load()); const auto& factories = instance.library->GetFactories(); + // OrtHardwareDevice instances to pass to GetSupportedDevices. sorted by type to be slightly more structured. + // the set of hardware devices is static so this can also be static. + const static std::vector sorted_devices = SortDevicesByType(); + for (auto* factory_ptr : factories) { ORT_ENFORCE(factory_ptr != nullptr, "Factory pointer was null. EpLibrary should prevent this. Library:", instance.library->RegistrationName()); auto& factory = *factory_ptr; - // for each device - for (const auto& device : DeviceDiscovery::GetDevices()) { - OrtKeyValuePairs* ep_metadata = nullptr; - OrtKeyValuePairs* ep_options = nullptr; - - if (factory.GetDeviceInfoIfSupported(&factory, &device, &ep_metadata, &ep_options)) { - auto ed = std::make_unique(); - ed->ep_name = factory.GetName(&factory); - ed->ep_vendor = factory.GetVendor(&factory); - ed->device = &device; - - if (ep_metadata) { - ed->ep_metadata = std::move(*ep_metadata); - delete ep_metadata; - } - - if (ep_options) { - ed->ep_options = std::move(*ep_options); - delete ep_options; - } - - ed->ep_factory = &factory; + std::array ep_devices{nullptr}; + size_t num_ep_devices = 0; + ORT_RETURN_IF_ERROR(ToStatus( + factory.GetSupportedDevices(&factory, sorted_devices.data(), sorted_devices.size(), + ep_devices.data(), ep_devices.size(), &num_ep_devices))); - instance.execution_devices.push_back(std::move(ed)); + for (size_t i = 0; i < num_ep_devices; ++i) { + if (ep_devices[i] != nullptr) { // should never happen but just in case... + instance.execution_devices.emplace_back(ep_devices[i]); // take ownership } } } diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc new file mode 100644 index 0000000000000..0cac00326392c --- /dev/null +++ b/onnxruntime/core/session/ep_api.cc @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/ep_api.h" + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/ort_apis.h" + +using namespace onnxruntime; +namespace OrtExecutionProviderApi { +ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory, + _In_ const OrtHardwareDevice* hardware_device, + _In_opt_ const OrtKeyValuePairs* ep_metadata, + _In_opt_ const OrtKeyValuePairs* ep_options, + _Out_ OrtEpDevice** ort_ep_device) { + API_IMPL_BEGIN + auto ep_device = std::make_unique(); + ep_device->device = hardware_device; + ep_device->ep_factory = ep_factory; + ep_device->ep_name = ep_factory->GetName(ep_factory); + ep_device->ep_vendor = ep_factory->GetVendor(ep_factory); + + if (ep_metadata) { + ep_device->ep_metadata = *ep_metadata; + } + + if (ep_options) { + ep_device->ep_options = *ep_options; + } + + *ort_ep_device = ep_device.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device) { + delete device; +} + +static constexpr OrtEpApi ort_ep_api = { + // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, + // and no functions can be removed (the implementation needs to change to return an error). + + &OrtExecutionProviderApi::CreateEpDevice, + &OrtExecutionProviderApi::ReleaseEpDevice, +}; + +// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned +static_assert(offsetof(OrtEpApi, ReleaseEpDevice) / sizeof(void*) == 1, + "Size of version 22 API cannot change"); // initial version in ORT 1.22 + +} // namespace OrtExecutionProviderApi + +ORT_API(const OrtEpApi*, OrtExecutionProviderApi::GetEpApi) { + return &ort_ep_api; +} diff --git a/onnxruntime/core/session/ep_api.h b/onnxruntime/core/session/ep_api.h new file mode 100644 index 0000000000000..23cd31cbdd861 --- /dev/null +++ b/onnxruntime/core/session/ep_api.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/session/onnxruntime_c_api.h" + +namespace OrtExecutionProviderApi { +// implementation that returns the API struct +ORT_API(const OrtEpApi*, GetEpApi); + +ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory, + _In_ const OrtHardwareDevice* hardware_device, + _In_opt_ const OrtKeyValuePairs* ep_metadata, + _In_opt_ const OrtKeyValuePairs* ep_options, + _Out_ OrtEpDevice** ep_device); + +ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device); +} // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index 1626d9c091893..23c25b4e7befb 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -16,12 +16,14 @@ struct ForwardToFactory { return static_cast(this_ptr)->GetVendor(); } - static bool ORT_API_CALL GetDeviceInfoIfSupported(const OrtEpFactory* this_ptr, - const OrtHardwareDevice* device, - OrtKeyValuePairs** ep_device_metadata, - OrtKeyValuePairs** ep_options_for_device) { - return static_cast(this_ptr)->GetDeviceInfoIfSupported(device, ep_device_metadata, - ep_options_for_device); + static OrtStatus* ORT_API_CALL GetSupportedDevices(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) { + return static_cast(this_ptr)->GetSupportedDevices(devices, num_devices, ep_devices, + max_ep_devices, num_ep_devices); } static OrtStatus* ORT_API_CALL CreateEp(OrtEpFactory* this_ptr, diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index 71774e11a7246..fd907302b6b8d 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -14,25 +14,27 @@ namespace onnxruntime { using Forward = ForwardToFactory; EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, - IsSupportedFunc&& is_supported_func, + GetSupportedFunc&& get_supported_func, CreateFunc&& create_func) : ep_name_{ep_name}, vendor_{vendor}, - is_supported_func_{std::move(is_supported_func)}, + get_supported_func_{std::move(get_supported_func)}, create_func_{create_func} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; OrtEpFactory::GetVendor = Forward::GetVendor; - OrtEpFactory::GetDeviceInfoIfSupported = Forward::GetDeviceInfoIfSupported; + OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; OrtEpFactory::ReleaseEp = Forward::ReleaseEp; } -bool EpFactoryInternal::GetDeviceInfoIfSupported(const OrtHardwareDevice* device, - OrtKeyValuePairs** ep_device_metadata, - OrtKeyValuePairs** ep_options_for_device) const { - return is_supported_func_(device, ep_device_metadata, ep_options_for_device); +OrtStatus* EpFactoryInternal::GetSupportedDevices(const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) { + return get_supported_func_(this, devices, num_devices, ep_devices, max_ep_devices, num_ep_devices); } OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices*/, @@ -57,7 +59,7 @@ OrtStatus* EpFactoryInternal::CreateIExecutionProvider(const OrtHardwareDevice* "EpFactoryInternal currently only supports one device at a time."); } - return create_func_(devices, ep_metadata_pairs, num_devices, session_options, session_logger, ep); + return create_func_(this, devices, ep_metadata_pairs, num_devices, session_options, session_logger, ep); } void EpFactoryInternal::ReleaseEp(OrtEp* /*ep*/) { diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index cfe3685e3e8e6..2dcc769ec635e 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -16,26 +16,34 @@ struct SessionOptions; class EpFactoryInternal : public OrtEpFactory { public: - using IsSupportedFunc = std::function; - - using CreateFunc = std::function; + + using CreateFunc = std::function* ep)>; EpFactoryInternal(const std::string& ep_name, const std::string& vendor, - IsSupportedFunc&& is_supported_func, + GetSupportedFunc&& get_supported_func, CreateFunc&& create_func); const char* GetName() const { return ep_name_.c_str(); } const char* GetVendor() const { return vendor_.c_str(); } - bool GetDeviceInfoIfSupported(_In_ const OrtHardwareDevice* device, - _Out_ OrtKeyValuePairs** ep_device_metadata, - _Out_ OrtKeyValuePairs** ep_options_for_device) const; + OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); // we don't implement this. CreateIExecutionProvider should be used. OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -55,10 +63,10 @@ class EpFactoryInternal : public OrtEpFactory { void ReleaseEp(OrtEp* ep); private: - const std::string ep_name_; // EP name library was registered with - const std::string vendor_; // EP vendor name - const IsSupportedFunc is_supported_func_; // function to check if the device is supported - const CreateFunc create_func_; // function to create the EP instance + const std::string ep_name_; // EP name library was registered with + const std::string vendor_; // EP vendor name + const GetSupportedFunc get_supported_func_; // function to return supported devices + const CreateFunc create_func_; // function to create the EP instance std::vector> eps_; // EP instances created by this factory }; diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index 0684e358b93e9..c515195c7e6bf 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -3,11 +3,13 @@ #include "core/session/ep_library_internal.h" +#include "core/framework/error_code_helper.h" #include "core/framework/session_options.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/abi_devices.h" #include "core/session/abi_logger.h" #include "core/session/abi_session_options_impl.h" +#include "core/session/ep_api.h" #include "core/session/ort_apis.h" #if defined(USE_DML) @@ -20,17 +22,27 @@ namespace onnxruntime { std::unique_ptr EpLibraryInternal::CreateCpuEp() { - const auto is_supported = [](const OrtHardwareDevice* device, - OrtKeyValuePairs** /*ep_metadata*/, - OrtKeyValuePairs** /*ep_options*/) -> bool { - if (device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - return true; + const auto get_supported = [](OrtEpFactory* factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) -> OrtStatus* { + size_t& num_ep_devices = *p_num_ep_devices; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + ORT_API_RETURN_IF_ERROR( + OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } } - return false; + return nullptr; }; - const auto create_cpu_ep = [](const OrtHardwareDevice* const* /*devices*/, + const auto create_cpu_ep = [](OrtEpFactory* /*factory*/, + const OrtHardwareDevice* const* /*devices*/, const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, size_t num_devices, const OrtSessionOptions* session_options, @@ -49,33 +61,47 @@ std::unique_ptr EpLibraryInternal::CreateCpuEp() { }; std::string ep_name = kCpuExecutionProvider; - auto cpu_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_cpu_ep); + auto cpu_factory = std::make_unique(ep_name, "Microsoft", get_supported, create_cpu_ep); return std::make_unique(std::move(cpu_factory)); } #if defined(USE_DML) std::unique_ptr EpLibraryInternal::CreateDmlEp() { static const std::string ep_name = kDmlExecutionProvider; - const auto is_supported = [](const OrtHardwareDevice* device, - OrtKeyValuePairs** /*ep_metadata*/, - OrtKeyValuePairs** ep_options) -> bool { - if (device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is associated with - // a specific device. - // How would we know what options should not allow user overrides if set in OrtEpDevice? - if (auto it = device->metadata.entries.find("DxgiAdapterNumber"); it != device->metadata.entries.end()) { - auto options = std::make_unique(); - options->Add("device_id", it->second.c_str()); - *ep_options = options.release(); + const auto is_supported = [](OrtEpFactory* factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) -> OrtStatus* { + size_t& num_ep_devices = *p_num_ep_devices; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + std::unique_ptr ep_options; + + // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is associated with + // a specific device. + // How would we know what options should not allow user overrides if set in OrtEpDevice? + if (auto it = device.metadata.entries.find("DxgiAdapterNumber"); it != device.metadata.entries.end()) { + ep_options = std::make_unique(); + ep_options->Add("device_id", it->second.c_str()); + } + + auto* api_status = OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, ep_options.get(), + &ep_devices[num_ep_devices++]); + + if (api_status != nullptr) { + return api_status; + } } - - return true; } - return false; + return nullptr; }; - const auto create_dml_ep = [](const OrtHardwareDevice* const* /*devices*/, + const auto create_dml_ep = [](OrtEpFactory* /*factory*/, + const OrtHardwareDevice* const* /*devices*/, const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, size_t num_devices, const OrtSessionOptions* session_options, @@ -106,20 +132,27 @@ std::unique_ptr EpLibraryInternal::CreateDmlEp() { std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { static const std::string ep_name = kWebGpuExecutionProvider; - const auto is_supported = [](const OrtHardwareDevice* device, - OrtKeyValuePairs** /*ep_metadata*/, - OrtKeyValuePairs** /*ep_options*/) -> bool { - if (device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // What is the correct behavior here to match the device if there are multiple GPUs? - // Should WebGPU default to picking the GPU with HighPerformanceIndex of 0? - // Or should we be setting the 'deviceId', 'webgpuInstance' and 'webgpuDevice' options for each GPU? - return true; + const auto is_supported = [](OrtEpFactory* factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) -> OrtStatus* { + size_t& num_ep_devices = *p_num_ep_devices; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // TODO: any metadata or options to add? + ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } } - return false; + return nullptr; }; - const auto create_webgpu_ep = [](const OrtHardwareDevice* const* /*devices*/, + const auto create_webgpu_ep = [](OrtEpFactory* /*factory*/, + const OrtHardwareDevice* const* /*devices*/, const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, size_t num_devices, const OrtSessionOptions* session_options, diff --git a/onnxruntime/core/session/ep_library_plugin.cc b/onnxruntime/core/session/ep_library_plugin.cc index 0cd03b2a4be07..3c873ec4a9aeb 100644 --- a/onnxruntime/core/session/ep_library_plugin.cc +++ b/onnxruntime/core/session/ep_library_plugin.cc @@ -51,6 +51,8 @@ Status EpLibraryPlugin::Load() { } Status EpLibraryPlugin::Unload() { + std::lock_guard lock{mutex_}; + // Call ReleaseEpFactory for all factories and unload the library. // Current implementation assumes any error is permanent so does not leave pieces around to re-attempt Unload. if (handle_) { diff --git a/onnxruntime/core/session/ep_library_plugin.h b/onnxruntime/core/session/ep_library_plugin.h index 58e95421e3d91..e2b02ccc654da 100644 --- a/onnxruntime/core/session/ep_library_plugin.h +++ b/onnxruntime/core/session/ep_library_plugin.h @@ -16,9 +16,9 @@ namespace onnxruntime { /// class EpLibraryPlugin : public EpLibrary { public: - EpLibraryPlugin(const std::string& registration_name, const ORTCHAR_T* library_path) + EpLibraryPlugin(const std::string& registration_name, std::filesystem::path library_path) : registration_name_{registration_name}, - library_path_{library_path} { + library_path_{std::move(library_path)} { } const char* RegistrationName() const override { diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc index 790a5a782de1c..73423a4744576 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/ep_library_provider_bridge.cc @@ -3,6 +3,7 @@ #include "core/session/ep_library_provider_bridge.h" +#include "core/common/status.h" #include "core/framework/error_code_helper.h" #include "core/framework/session_options.h" #include "core/providers/cuda/cuda_provider_options.h" @@ -13,17 +14,48 @@ namespace onnxruntime { Status EpLibraryProviderBridge::Load() { - // wrap the EpLibraryPlugin factories that were created by calling CreateEpFactories. - // use GetDeviceInfoIfSupported from the factory. + std::lock_guard lock{mutex_}; + + if (!factories_.empty()) { + // already loaded + return Status::OK(); + } + + // if we have been unloaded we can't just be reloaded. + if (!ep_library_plugin_ || !provider_library_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EpLibraryProviderBridge has been unloaded. " + "Please create a new instance using LoadPluginOrProviderBridge."); + } + + // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. + // use GetSupportedDevices from the library's factory. + // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. + // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. for (const auto& factory : ep_library_plugin_->GetFactories()) { - const auto is_supported_fn = [&factory](const OrtHardwareDevice* device, - OrtKeyValuePairs** ep_metadata, - OrtKeyValuePairs** ep_options) -> bool { - return factory->GetDeviceInfoIfSupported(factory, device, ep_metadata, ep_options); + const auto is_supported_fn = [&factory](OrtEpFactory* ep_factory_internal, // from factory_ptrs_ + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) -> OrtStatus* { + ORT_API_RETURN_IF_ERROR(factory->GetSupportedDevices(factory, devices, num_devices, ep_devices, max_ep_devices, + num_ep_devices)); + + // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. + for (size_t i = 0; i < *num_ep_devices; ++i) { + auto* ep_device = ep_devices[i]; + if (ep_device) { + ep_device->ep_factory = ep_factory_internal; + } + } + + return nullptr; }; - const auto create_fn = [this, &factory](const OrtHardwareDevice* const* devices, + const auto create_fn = [this, &factory](OrtEpFactory* /*ep_factory_internal from factory_ptrs_*/, + const OrtHardwareDevice* const* devices, const OrtKeyValuePairs* const* ep_metadata_pairs, size_t num_devices, const OrtSessionOptions* session_options, @@ -42,7 +74,6 @@ Status EpLibraryProviderBridge::Load() { factory->GetVendor(factory), is_supported_fn, create_fn); - factory_ptrs_.push_back(internal_factory.get()); internal_factory_ptrs_.push_back(internal_factory.get()); factories_.push_back(std::move(internal_factory)); @@ -52,7 +83,19 @@ Status EpLibraryProviderBridge::Load() { } Status EpLibraryProviderBridge::Unload() { + std::lock_guard lock{mutex_}; + + internal_factory_ptrs_.clear(); + factory_ptrs_.clear(); + factories_.clear(); + + // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. + ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); + ep_library_plugin_ = nullptr; + provider_library_->Unload(); + provider_library_ = nullptr; + return Status::OK(); } diff --git a/onnxruntime/core/session/ep_library_provider_bridge.h b/onnxruntime/core/session/ep_library_provider_bridge.h index 5f85192866cf4..3c7f083df227e 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/ep_library_provider_bridge.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include "core/session/ep_library.h" #include "core/session/ep_factory_internal.h" @@ -44,10 +45,11 @@ class EpLibraryProviderBridge : public EpLibrary { ORT_DISALLOW_COPY_AND_ASSIGNMENT(EpLibraryProviderBridge); private: + std::mutex mutex_; std::unique_ptr provider_library_; // provider bridge EP library // EpLibraryPlugin that provides the CreateEpFactories and ReleaseEpFactory implementations. - // we wrap the factories it contains to pass through GetDeviceInfoIfSupported calls, and + // we wrap the OrtEpFactory instances it contains to pass through GetSupportedDevices calls, and // implement EpFactoryInternal::CreateIExecutionProvider by calling Provider::CreateIExecutionProvider. std::unique_ptr ep_library_plugin_; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index bae355bb4e518..8ec7312cc6354 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2865,6 +2865,8 @@ Status InferenceSession::Run(const RunOptions& run_options, } #endif + reset_saturation_count(); + // As N+1 inference runs (N for memory allocation and 1 for graph capturing) // are needed before replaying the captured graph, here run N inference runs recursively until graph captured, // so that users just need one session run to capture the graph. diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 2a395050636ba..a21388d1e9918 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -58,6 +58,8 @@ class IExecutionProvider; class IOBinding; struct Notification; +void reset_saturation_count(); + #ifdef ENABLE_TRAINING struct PartialGraphExecutionState; using OrtValueCache = InlinedHashMap; diff --git a/onnxruntime/core/session/lora_adapters.cc b/onnxruntime/core/session/lora_adapters.cc index 466edce187a56..85ea958981e2c 100644 --- a/onnxruntime/core/session/lora_adapters.cc +++ b/onnxruntime/core/session/lora_adapters.cc @@ -12,13 +12,13 @@ #include "core/session/allocator_adapters.h" #include "core/session/ort_apis.h" -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) #include "core/providers/cuda/cuda_provider_factory.h" #endif namespace onnxruntime { -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); #endif @@ -58,7 +58,7 @@ static std::unique_ptr GetDataTransfer(const OrtMemoryInfo& mem_i } if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) { -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) auto* cuda_provider_info = TryGetProviderInfo_CUDA(); if (cuda_provider_info != nullptr) { data_transfer = cuda_provider_info->CreateGPUDataTransfer(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f2d03610bec1e..b5c271594055a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -33,6 +33,7 @@ #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" +#include "core/session/ep_api.h" #include "core/session/ep_library_internal.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" @@ -44,7 +45,7 @@ #include "core/session/ort_env.h" #include "core/session/utils.h" -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cuda/cuda_execution_provider_info.h" namespace onnxruntime { @@ -330,7 +331,7 @@ std::unique_ptr GetDataTransfer(const OrtDevice& src_device, cons if (src_device.Type() == OrtDevice::CPU && dst_device.Type() == OrtDevice::CPU) { return std::make_unique(); } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) if (src_device.Type() == OrtDevice::GPU || dst_device.Type() == OrtDevice::GPU) { if (auto* provider_info = TryGetProviderInfo_CUDA()) { return provider_info->CreateGPUDataTransfer(); @@ -2511,7 +2512,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS return nullptr; API_IMPL_END } -#else // defined(ORT_MINIMAL_BUILD) + +ORT_API(const OrtEpApi*, OrtApis::GetEpApi) { + return OrtExecutionProviderApi::GetEpApi(); +} + +#else // defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { API_IMPL_BEGIN @@ -2545,6 +2551,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS API_IMPL_END } + +ORT_API(const OrtEpApi*, OrtApis::GetEpApi) { + fprintf(stderr, "The Execution Provider API is not supported in a minimal build.\n"); + return nullptr; +} + #endif // !defined(ORT_MINIMAL_BUILD) // OrtEpDevice accessors @@ -3012,6 +3024,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::EpDevice_EpMetadata, &OrtApis::EpDevice_EpOptions, &OrtApis::EpDevice_Device, + + &OrtApis::GetEpApi, // End of Version 22 - DO NOT MODIFY ABOVE (see above text for more information) }; @@ -3047,7 +3061,7 @@ static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeo // no additions in version 19, 20, and 21 static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change"); -static_assert(offsetof(OrtApi, EpDevice_Device) / sizeof(void*) == 314, "Size of version 22 API cannot change"); +static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 315, "Size of version 22 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.23.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 76c5e7bf9c26b..0033eb0d604f2 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -588,4 +588,6 @@ ORT_API(const char*, EpDevice_EpVendor, _In_ const OrtEpDevice* ep_device); ORT_API(const OrtKeyValuePairs*, EpDevice_EpMetadata, _In_ const OrtEpDevice* ep_device); ORT_API(const OrtKeyValuePairs*, EpDevice_EpOptions, _In_ const OrtEpDevice* ep_device); ORT_API(const OrtHardwareDevice*, EpDevice_Device, _In_ const OrtEpDevice* ep_device); + +ORT_API(const OrtEpApi*, GetEpApi); } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_library.h b/onnxruntime/core/session/provider_bridge_library.h index 7c14fbd353a3d..a131e1ebbd122 100644 --- a/onnxruntime/core/session/provider_bridge_library.h +++ b/onnxruntime/core/session/provider_bridge_library.h @@ -9,8 +9,13 @@ namespace onnxruntime { struct Provider; +enum class ProviderLibraryPathType { + Default, + Absolute, +}; + struct ProviderLibrary { - ProviderLibrary(const ORTCHAR_T* filename, bool unload = true); + ProviderLibrary(const ORTCHAR_T* filename, bool unload = true, ProviderLibraryPathType pathType = ProviderLibraryPathType::Default); ~ProviderLibrary(); Status Load(); @@ -19,8 +24,9 @@ struct ProviderLibrary { private: std::mutex mutex_; - const ORTCHAR_T* filename_; + const ORTCHAR_T* const filename_; bool unload_; + const bool absolute_; bool initialized_{}; Provider* provider_{}; void* handle_{}; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index de56877ffe75a..7fcaee48581f6 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -10,6 +10,7 @@ #include "core/common/inlined_containers.h" #include "core/common/path_string.h" #include "core/common/string_helper.h" + #include "core/framework/allocator_utils.h" #include "core/framework/compute_capability.h" #include "core/framework/config_options.h" @@ -122,7 +123,9 @@ using EtwRegistrationManager_EtwInternalCallback = EtwRegistrationManager::EtwIn #include "core/providers/nv_tensorrt_rtx/nv_provider_factory.h" #include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" -#if !defined(ORT_MINIMAL_BUILD) && (defined(USE_TENSORRT) || defined(USE_NV)) +#if !defined(ORT_MINIMAL_BUILD) && \ + (defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) || \ + defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE)) #include "core/session/onnxruntime_session_options_config_keys.h" #endif @@ -159,6 +162,7 @@ ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX(); ProviderInfo_MIGraphX& GetProviderInfo_MIGraphX(); ProviderInfo_Nv* TryGetProviderInfo_Nv(); ProviderInfo_Nv& GetProviderInfo_Nv(); +ProviderInfo_OpenVINO* TryGetProviderInfo_OpenVINO(); ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector& ops); struct TensorShapeProto_Dimension_Iterator_Impl : TensorShapeProto_Dimension_Iterator { @@ -1727,8 +1731,8 @@ bool InitProvidersSharedLibrary() try { return false; } -ProviderLibrary::ProviderLibrary(const ORTCHAR_T* filename, bool unload) - : filename_{filename}, unload_{unload} { +ProviderLibrary::ProviderLibrary(const ORTCHAR_T* filename, bool unload, ProviderLibraryPathType pathType) + : filename_{filename}, unload_{unload}, absolute_{pathType == ProviderLibraryPathType::Absolute} { } ProviderLibrary::~ProviderLibrary() { @@ -1744,8 +1748,16 @@ Status ProviderLibrary::Load() { std::lock_guard lock{mutex_}; s_library_shared.Ensure(); - auto full_path = Env::Default().GetRuntimePath() + filename_; - ORT_RETURN_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, false, &handle_)); + if (absolute_) { + // If filename_ is not absolute it should not be loaded. + if (!std::filesystem::path{filename_}.is_absolute()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "An absolute path must be specified."); + } + ORT_RETURN_IF_ERROR(Env::Default().LoadDynamicLibrary(filename_, false, &handle_)); + } else { + auto full_path = Env::Default().GetRuntimePath() + filename_; + ORT_RETURN_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, false, &handle_)); + } Provider* (*PGetProvider)(); ORT_RETURN_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle_, "GetProvider", (void**)&PGetProvider)); @@ -1794,6 +1806,7 @@ void ProviderLibrary::Unload() { } } + initialized_ = false; handle_ = nullptr; provider_ = nullptr; } @@ -1916,13 +1929,23 @@ OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const return cuda_options_converted; } -std::shared_ptr CudaProviderFactoryCreator::Create(const OrtCUDAProviderOptions* provider_options) { +std::shared_ptr CudaProviderFactoryCreator::Create( + const OrtCUDAProviderOptions* provider_options) try { OrtCUDAProviderOptionsV2 cuda_options_converted = onnxruntime::OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(provider_options); return s_library_cuda.Get().CreateExecutionProviderFactory(&cuda_options_converted); +} catch (const std::exception& exception) { + // Will get an exception when fail to load EP library. + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } -std::shared_ptr CudaProviderFactoryCreator::Create(const OrtCUDAProviderOptionsV2* provider_options) { +std::shared_ptr CudaProviderFactoryCreator::Create( + const OrtCUDAProviderOptionsV2* provider_options) try { return s_library_cuda.Get().CreateExecutionProviderFactory(provider_options); +} catch (const std::exception& exception) { + // Will get an exception when fail to load EP library. + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } std::shared_ptr RocmProviderFactoryCreator::Create(const OrtROCMProviderOptions* provider_options) { @@ -1991,25 +2014,55 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti return trt_options_converted; } -std::shared_ptr TensorrtProviderFactoryCreator::Create(int device_id) { +std::shared_ptr TensorrtProviderFactoryCreator::Create(int device_id) try { return s_library_tensorrt.Get().CreateExecutionProviderFactory(device_id); +} catch (const std::exception& exception) { + // Will get an exception when fail to load EP library. + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } -std::shared_ptr TensorrtProviderFactoryCreator::Create(const OrtTensorRTProviderOptions* provider_options) { +std::shared_ptr TensorrtProviderFactoryCreator::Create( + const OrtTensorRTProviderOptions* provider_options) try { OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(provider_options); return s_library_tensorrt.Get().CreateExecutionProviderFactory(&trt_options_converted); +} catch (const std::exception& exception) { + // Will get an exception when fail to load EP library. + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } -std::shared_ptr TensorrtProviderFactoryCreator::Create(const OrtTensorRTProviderOptionsV2* provider_options) { +std::shared_ptr TensorrtProviderFactoryCreator::Create( + const OrtTensorRTProviderOptionsV2* provider_options) try { return s_library_tensorrt.Get().CreateExecutionProviderFactory(provider_options); +} catch (const std::exception& exception) { + // Will get an exception when fail to load EP library. + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } -std::shared_ptr NvProviderFactoryCreator::Create(int device_id) { +std::shared_ptr NvProviderFactoryCreator::Create(int device_id) try { return s_library_nv.Get().CreateExecutionProviderFactory(device_id); +} catch (const std::exception& exception) { + // Will get an exception when fail to load EP library. + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } -std::shared_ptr NvProviderFactoryCreator::Create(const ProviderOptions& provider_options) { - return s_library_nv.Get().CreateExecutionProviderFactory(&provider_options); +std::shared_ptr NvProviderFactoryCreator::Create( + const ProviderOptions& provider_options, const SessionOptions* session_options) try { + const ConfigOptions* config_options = nullptr; + if (session_options != nullptr) { + config_options = &session_options->config_options; + } + + std::array configs_array = {&provider_options, config_options}; + const void* arg = reinterpret_cast(&configs_array); + return s_library_nv.Get().CreateExecutionProviderFactory(arg); +} catch (const std::exception& exception) { + // Will get an exception when fail to load EP library. + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) { @@ -2055,8 +2108,8 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O } #if !BUILD_QNN_EP_STATIC_LIB -std::shared_ptr QNNProviderFactoryCreator::Create(const ProviderOptions& provider_options_map, - const SessionOptions* session_options) { +std::shared_ptr QNNProviderFactoryCreator::Create( + const ProviderOptions& provider_options_map, const SessionOptions* session_options) try { const ConfigOptions* config_options = nullptr; if (session_options != nullptr) { config_options = &session_options->config_options; @@ -2065,11 +2118,15 @@ std::shared_ptr QNNProviderFactoryCreator::Create(con std::array configs_array = {&provider_options_map, config_options}; const void* arg = reinterpret_cast(&configs_array); return s_library_qnn.Get().CreateExecutionProviderFactory(arg); +} catch (const std::exception& exception) { + // Will get an exception when fail to load EP library. + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } #endif // !BUILD_QNN_EP_STATIC_LIB std::shared_ptr OpenVINOProviderFactoryCreator::Create( - const ProviderOptions* provider_options_map, const SessionOptions* session_options) { + const ProviderOptions* provider_options_map, const SessionOptions* session_options) try { // Append session options applicable for EP to EP Provider options. const ConfigOptions* config_options = nullptr; if (session_options != nullptr) { @@ -2079,18 +2136,29 @@ std::shared_ptr OpenVINOProviderFactoryCreator::Creat std::array configs_array = {provider_options_map, config_options}; const void* arg = reinterpret_cast(&configs_array); return s_library_openvino.Get().CreateExecutionProviderFactory(arg); +} catch (const std::exception& exception) { + // Will get an exception when fail to load EP library. + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } std::shared_ptr DnnlProviderFactoryCreator::Create(const OrtDnnlProviderOptions* dnnl_options) { return s_library_dnnl.Get().CreateExecutionProviderFactory(dnnl_options); } -std::shared_ptr VitisAIProviderFactoryCreator::Create(const ProviderOptions& provider_options) { +std::shared_ptr VitisAIProviderFactoryCreator::Create( + const ProviderOptions& provider_options) try { return s_library_vitisai.Get().CreateExecutionProviderFactory(&provider_options); +} catch (const std::exception& exception) { + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } -ProviderInfo_OpenVINO* GetProviderInfo_OpenVINO() { +ProviderInfo_OpenVINO* TryGetProviderInfo_OpenVINO() try { return reinterpret_cast(s_library_openvino.Get().GetInfo()); +} catch (const std::exception& exception) { + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; } ProviderInfo_TensorRT* TryGetProviderInfo_TensorRT() try { @@ -2413,7 +2481,7 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessi ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, [[maybe_unused]] _In_ int device_id) { API_IMPL_BEGIN -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) if (auto* info = onnxruntime::TryGetProviderInfo_CUDA()) return info->SetCurrentGpuDeviceId(device_id); #endif @@ -2430,7 +2498,7 @@ ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, [[maybe_unused]] _In_ int de ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, [[maybe_unused]] _In_ int* device_id) { API_IMPL_BEGIN -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) if (auto* info = onnxruntime::TryGetProviderInfo_CUDA()) return info->GetCurrentGpuDeviceId(device_id); #endif @@ -2480,7 +2548,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, std::shared_ptr factory; -#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) +#if !defined(ORT_MINIMAL_BUILD) && (defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE)) auto ep_context_cache_enabled_from_provider_options = tensorrt_options->trt_dump_ep_context_model != 0; auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; @@ -2542,7 +2610,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out) { API_IMPL_BEGIN -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) auto options = std::make_unique(); *out = options.release(); return nullptr; @@ -2559,7 +2627,7 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions, _In_reads_(num_keys) const char* const* provider_options_values, size_t num_keys) { API_IMPL_BEGIN -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) onnxruntime::ProviderOptions provider_options_map; for (size_t i = 0; i != num_keys; ++i) { if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || @@ -2583,7 +2651,11 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions, API_IMPL_END } -#if defined(USE_TENSORRT) || defined(USE_CUDA) || defined(USE_CANN) || defined(USE_DNNL) || defined(USE_ROCM) || defined(USE_NV) +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) || \ + defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || \ + defined(USE_CANN) || \ + defined(USE_DNNL) || \ + defined(USE_ROCM) static std::string BuildOptionsString(const onnxruntime::ProviderOptions::iterator& begin, const onnxruntime::ProviderOptions::iterator& end) { std::ostringstream options; @@ -2602,7 +2674,7 @@ static std::string BuildOptionsString(const onnxruntime::ProviderOptions::iterat ORT_API_STATUS_IMPL(OrtApis::GetTensorRTProviderOptionsAsString, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr) { API_IMPL_BEGIN -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) onnxruntime::ProviderOptions options = onnxruntime::GetProviderInfo_Tensorrt(tensorrt_options); std::string options_str = BuildOptionsString(options.begin(), options.end()); *ptr = onnxruntime::StrDup(options_str, allocator); @@ -2621,7 +2693,7 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptionsWithValue, _In_ const char* key, _In_ void* value) { API_IMPL_BEGIN -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) // current provider option that has pointer data type (excluding const char*) is 'user_compute_stream' if (strcmp(key, "user_compute_stream") == 0) { tensorrt_options->has_user_compute_stream = 1; @@ -2646,7 +2718,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorRTProviderOptionsByName, _In_ const char* key, _Outptr_ void** ptr) { API_IMPL_BEGIN -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) // current provider option that has pointer data type (excluding const char*) is 'user_compute_stream' if (strcmp(key, "user_compute_stream") == 0) { *ptr = tensorrt_options->user_compute_stream; @@ -2664,7 +2736,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorRTProviderOptionsByName, } ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensorRTProviderOptionsV2* ptr) { -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) if (ptr != nullptr) { delete[] ptr->trt_int8_calibration_table_name; delete[] ptr->trt_engine_cache_path; @@ -2702,7 +2774,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_CUDA_V2, _In_ ORT_API_STATUS_IMPL(OrtApis::CreateCUDAProviderOptions, _Outptr_ OrtCUDAProviderOptionsV2** out) { API_IMPL_BEGIN -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) auto options = std::make_unique(); *out = options.release(); return nullptr; @@ -2719,7 +2791,7 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateCUDAProviderOptions, _In_reads_(num_keys) const char* const* provider_options_values, size_t num_keys) { API_IMPL_BEGIN -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) onnxruntime::ProviderOptions provider_options_map; for (size_t i = 0; i != num_keys; ++i) { if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || @@ -2746,7 +2818,7 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateCUDAProviderOptions, ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsAsString, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr) { API_IMPL_BEGIN -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) onnxruntime::ProviderOptions options = onnxruntime::GetProviderInfo_Cuda(cuda_options); std::string options_str = BuildOptionsString(options.begin(), options.end()); *ptr = onnxruntime::StrDup(options_str, allocator); @@ -2765,7 +2837,7 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateCUDAProviderOptionsWithValue, _In_ const char* key, _In_ void* value) { API_IMPL_BEGIN -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) if (strcmp(key, "user_compute_stream") == 0) { cuda_options->has_user_compute_stream = 1; cuda_options->user_compute_stream = value; @@ -2785,7 +2857,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsByName, _In_ const char* key, _Outptr_ void** ptr) { API_IMPL_BEGIN -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) if (strcmp(key, "user_compute_stream") == 0) { *ptr = cuda_options->user_compute_stream; } else { @@ -2802,7 +2874,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsByName, } ORT_API(void, OrtApis::ReleaseCUDAProviderOptions, _Frees_ptr_opt_ OrtCUDAProviderOptionsV2* ptr) { -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) std::unique_ptr p(ptr); #else ORT_UNUSED_PARAMETER(ptr); diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 82201741cb047..2108626e36853 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -26,7 +26,7 @@ #include "core/providers/dml/dml_provider_factory_creator.h" #endif -#if defined(USE_NV) +#if defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE) #include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" #endif using namespace onnxruntime; @@ -150,6 +150,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, (std::string(provider_name) + " execution provider is not supported in this build. ").c_str()); }; + auto create_failed_to_load_provider_status = [&provider_name]() { + return OrtApis::CreateStatus(ORT_FAIL, + (std::string("Failed to load provider ") + provider_name).c_str()); + }; + auto create_unknown_provider_status = [&provider_name](gsl::span supported_eps) -> OrtStatus* { std::ostringstream str_builder; str_builder << "Unknown provider name '" << provider_name << "'. " @@ -200,17 +205,24 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, break; } case EpID::QNN: { -#if defined(USE_QNN) - options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options, &(options->value))); +#if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE) + if (auto ep_factory = QNNProviderFactoryCreator::Create(provider_options, &(options->value)); ep_factory) { + options->provider_factories.push_back(std::move(ep_factory)); + } else { + status = create_failed_to_load_provider_status(); + } #else status = create_not_supported_status(); #endif break; } case EpID::OpenVINO: { -#if defined(USE_OPENVINO) - options->provider_factories.push_back(OpenVINOProviderFactoryCreator::Create(&provider_options, - &(options->value))); +#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) + if (auto ep_factory = OpenVINOProviderFactoryCreator::Create(&provider_options, &(options->value)); ep_factory) { + options->provider_factories.push_back(std::move(ep_factory)); + } else { + status = create_failed_to_load_provider_status(); + } #else status = create_not_supported_status(); #endif @@ -271,7 +283,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, break; } case EpID::VitisAI: { -#ifdef USE_VITISAI +#if defined(USE_VITISAI) || defined(USE_VITISAI_PROVIDER_INTERFACE) status = OrtApis::SessionOptionsAppendExecutionProvider_VitisAI(options, provider_options_keys, provider_options_values, num_keys); #else @@ -288,12 +300,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, break; } case EpID::NvTensorRtRtx: { -#if defined(USE_NV) - auto factory = onnxruntime::NvProviderFactoryCreator::Create(provider_options); - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Nv_TensorRT_RTX: Failed to load shared library"); +#if defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE) + auto factory = onnxruntime::NvProviderFactoryCreator::Create(provider_options, &(options->value)); + if (factory) { + options->provider_factories.push_back(factory); + } else { + status = create_failed_to_load_provider_status(); } - options->provider_factories.push_back(factory); #else status = create_not_supported_status(); #endif @@ -301,6 +314,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, } default: ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(create_failed_to_load_provider_status); status = create_unknown_provider_status(supported_eps); } diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 2b0e2035b5ee6..adb019fdde86d 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -240,15 +240,25 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, const ORTCHAR_T* library_path, std::unique_ptr& ep_library, std::vector& internal_factories) { + // If the `library_path` is absolute, use it as-is. Otherwise follow the precedent of ProviderLibrary::Load and make + // it absolute by combining it with the OnnxRuntime location. + std::filesystem::path resolved_library_path{library_path}; + + if (!resolved_library_path.is_absolute()) { + resolved_library_path = Env::Default().GetRuntimePath() / std::move(resolved_library_path); + } + // if it's a provider bridge library we need to create ProviderLibrary first to ensure the dependencies are loaded // like the onnxruntime_provider_shared library. - auto provider_library = std::make_unique(library_path); + auto provider_library = std::make_unique(resolved_library_path.native().c_str(), + true, + ProviderLibraryPathType::Absolute); bool is_provider_bridge = provider_library->Load() == Status::OK(); // library has GetProvider LOGS_DEFAULT(INFO) << "Loading EP library: " << library_path << (is_provider_bridge ? " as a provider bridge" : " as a plugin"); // create EpLibraryPlugin to ensure CreateEpFactories and ReleaseEpFactory are available - auto ep_library_plugin = std::make_unique(registration_name, library_path); + auto ep_library_plugin = std::make_unique(registration_name, std::move(resolved_library_path)); ORT_RETURN_IF_ERROR(ep_library_plugin->Load()); if (is_provider_bridge) { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 60bc6865b2ccf..0f15c5fbbdba0 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -421,7 +421,7 @@ static std::unique_ptr LoadExecutionProvider( return ep_factory->CreateProvider(); } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) const CUDAExecutionProviderInfo GetCudaExecutionProviderInfo(ProviderInfo_CUDA* cuda_provider_info, const ProviderOptionsMap& provider_options_map) { ORT_ENFORCE(cuda_provider_info); @@ -475,7 +475,7 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* } #endif -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) { if (auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT()) { auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { @@ -507,7 +507,7 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti } #endif -#ifdef USE_NV +#if defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE) void RegisterNvTensorRTRtxPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) { if (auto* nv_tensorrt_rtx_provider_info = TryGetProviderInfo_Nv()) { auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { @@ -548,7 +548,7 @@ std::unique_ptr CreateExecutionProviderInstance( session_options.enable_cpu_mem_arena) ->CreateProvider(); } else if (type == kTensorrtExecutionProvider) { -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) // If the environment variable 'ORT_TENSORRT_UNAVAILABLE' exists, then we do not load TensorRT. This is set by _ld_preload for the manylinux case // as in that case, trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies. if (Env::Default().GetEnvironmentVar("ORT_TENSORRT_UNAVAILABLE").empty()) { @@ -885,12 +885,13 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kNvTensorRTRTXExecutionProvider) { -#ifdef USE_NV +#if defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE) if (Env::Default().GetEnvironmentVar("ORT_NV_TENSORRT_RTX_UNAVAILABLE").empty()) { auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { ProviderOptions info = it->second; - if (std::shared_ptr nv_tensorrt_rtx_provider_factory = onnxruntime::NvProviderFactoryCreator::Create(info)) { + if (std::shared_ptr nv_tensorrt_rtx_provider_factory = onnxruntime::NvProviderFactoryCreator::Create( + info, &session_options)) { return nv_tensorrt_rtx_provider_factory->CreateProvider(); } } else { @@ -1033,7 +1034,7 @@ std::unique_ptr CreateExecutionProviderInstance( } #endif } else if (type == kCudaExecutionProvider) { -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) // If the environment variable 'CUDA_UNAVAILABLE' exists, then we do not load cuda. // This is set by _ld_preload for the manylinux case as in that case, // trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies. @@ -1050,6 +1051,7 @@ std::unique_ptr CreateExecutionProviderInstance( return cuda_provider_info->CreateExecutionProviderFactory(info)->CreateProvider(); } } +#if defined(USE_CUDA) LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Require cuDNN " << CUDNN_MAJOR << ".* and " @@ -1060,7 +1062,15 @@ std::unique_ptr CreateExecutionProviderInstance( << ". Please install all dependencies as mentioned in the GPU requirements page" " (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), " "make sure they're in the PATH, and that your GPU is supported."; -#endif +#elif defined(USE_CUDA_PROVIDER_INTERFACE) + // Can't include "cuda.h", so don't print version info. + LOGS_DEFAULT(WARNING) << "Failed to create " + << type + << ". Please install all dependencies as mentioned in the GPU requirements page" + " (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), " + "make sure they're in the PATH, and that your GPU is supported."; +#endif // defined(USE_CUDA) +#endif // defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) } else if (type == kRocmExecutionProvider) { #ifdef USE_ROCM if (auto* rocm_provider_info = TryGetProviderInfo_ROCM()) { @@ -1111,7 +1121,7 @@ std::unique_ptr CreateExecutionProviderInstance( return onnxruntime::DnnlProviderFactoryCreator::Create(&dnnl_options)->CreateProvider(); #endif } else if (type == kOpenVINOExecutionProvider) { -#ifdef USE_OPENVINO +#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) ProviderOptions OV_provider_options_map; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { @@ -1192,14 +1202,21 @@ std::unique_ptr CreateExecutionProviderInstance( } #endif } else if (type == kVitisAIExecutionProvider) { -#ifdef USE_VITISAI +#if defined(USE_VITISAI) || defined(USE_VITISAI_PROVIDER_INTERFACE) ProviderOptions info{}; const auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { info = it->second; } info["session_options"] = std::to_string((uintptr_t)(void*)&session_options); - return onnxruntime::VitisAIProviderFactoryCreator::Create(info)->CreateProvider(); + if (auto vitisai_factory = onnxruntime::VitisAIProviderFactoryCreator::Create(info); vitisai_factory) { + return vitisai_factory->CreateProvider(); + } + LOGS_DEFAULT(WARNING) << "Failed to create " + << type + << ". Please reference " + << "https://onnxruntime.ai/docs/execution-providers/" + << "Vitis-AI-ExecutionProvider.html#requirements to ensure all dependencies are met."; #endif } else if (type == kAclExecutionProvider) { #ifdef USE_ACL @@ -1315,11 +1332,18 @@ std::unique_ptr CreateExecutionProviderInstance( return onnxruntime::AzureProviderFactoryCreator::Create({})->CreateProvider(); #endif } else if (type == kQnnExecutionProvider) { -#ifdef USE_QNN +#if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE) auto cit = provider_options_map.find(type); - return onnxruntime::QNNProviderFactoryCreator::Create( - cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options) - ->CreateProvider(); + auto qnn_factory = onnxruntime::QNNProviderFactoryCreator::Create( + cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options); + if (qnn_factory) { + return qnn_factory->CreateProvider(); + } + LOGS_DEFAULT(WARNING) << "Failed to create " + << type + << ". Please reference " + << "https://onnxruntime.ai/docs/execution-providers/QNN-ExecutionProvider.html" + << " to ensure all dependencies are met."; #endif } else { // check whether it is a dynamic load EP: @@ -1442,9 +1466,8 @@ bool CheckIfTensor(const std::vector& def_list, return type_proto.has_tensor_type(); } -#if defined(USE_OPENVINO) || \ - defined(USE_CUDA) || \ - defined(USE_ROCM) +#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) || \ + defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) static void LogDeprecationWarning( const std::string& deprecated, const optional& alternative = nullopt) { LOGS_DEFAULT(WARNING) << "This is DEPRECATED and will be removed in the future: " << deprecated; @@ -1506,10 +1529,10 @@ void addGlobalMethods(py::module& m) { } }); -#ifdef USE_OPENVINO +#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( "get_available_openvino_device_ids", []() -> std::vector { - if (auto* info = GetProviderInfo_OpenVINO()) { + if (auto* info = TryGetProviderInfo_OpenVINO()) { return info->GetAvailableDevices(); } return {}; @@ -1535,7 +1558,7 @@ void addGlobalMethods(py::module& m) { "Gets the dynamically selected OpenVINO device type for inference."); #endif -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) || defined(USE_ROCM) /* * The following set_* methods are deprecated. * @@ -1583,13 +1606,13 @@ void addGlobalMethods(py::module& m) { }); #endif -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) m.def( "register_tensorrt_plugins_as_custom_ops", [](PySessionOptions& so, const ProviderOptions& options) { RegisterTensorRTPluginsAsCustomOps(so, options); }, "Register TensorRT plugins as custom ops."); #endif -#ifdef USE_NV +#if defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE) m.def( "register_nv_tensorrt_rtx_plugins_as_custom_ops", [](PySessionOptions& so, const ProviderOptions& options) { RegisterNvTensorRTRtxPluginsAsCustomOps(so, options); }, "Register NV TensorRT RTX plugins as custom ops."); diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc b/onnxruntime/python/onnxruntime_pybind_state_common.cc index 55ea264571220..c3ca74526de03 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.cc +++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc @@ -9,7 +9,7 @@ namespace py = pybind11; const std::string onnxruntime::python::SessionObjectInitializer::default_logger_id = "Default"; -#ifdef USE_OPENVINO +#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) // TODO remove deprecated global config std::string openvino_device_type; #endif @@ -19,7 +19,7 @@ OrtDevice::DeviceId cuda_device_id = 0; // TODO remove deprecated global config size_t gpu_mem_limit = std::numeric_limits::max(); -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) // TODO remove deprecated global config OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearchExhaustive; // TODO remove deprecated global config diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 168880517c3a5..3ae5c0d289c21 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -49,7 +49,7 @@ struct OrtStatus { #define BACKEND_MIGRAPHX "" #endif -#ifdef USE_OPENVINO +#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) #if OPENVINO_CONFIG_CPU #define BACKEND_OPENVINO "-OPENVINO_CPU" @@ -70,11 +70,14 @@ struct OrtStatus { #elif OPENVINO_DISABLE_NPU_FALLBACK #define BACKEND_OPENVINO "-OPENVINO_DISABLE_NPU_FALLBACK" -#endif #else #define BACKEND_OPENVINO "" -#endif +#endif // OPENVINO_CONFIG_* + +#else +#define BACKEND_OPENVINO "" +#endif // defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) #if USE_OPENBLAS #define BACKEND_OPENBLAS "-OPENBLAS" @@ -112,7 +115,7 @@ struct OrtStatus { #define BACKEND_WEBGPU "" #endif -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cuda/cuda_execution_provider_info.h" #endif @@ -120,20 +123,20 @@ struct OrtStatus { #include "core/providers/rocm/rocm_provider_factory.h" #include "core/providers/rocm/rocm_execution_provider_info.h" #endif -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) #include "core/providers/tensorrt/tensorrt_provider_factory.h" #endif -#ifdef USE_NV +#if defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE) #include "core/providers/nv_tensorrt_rtx/nv_provider_factory.h" #endif #ifdef USE_MIGRAPHX #include "core/providers/migraphx/migraphx_provider_factory.h" #endif -#ifdef USE_OPENVINO +#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) #include "core/providers/openvino/openvino_provider_factory.h" // TODO remove deprecated global config namespace onnxruntime { -ProviderInfo_OpenVINO* GetProviderInfo_OpenVINO(); +ProviderInfo_OpenVINO* TryGetProviderInfo_OpenVINO(); namespace python { extern std::string openvino_device_type; } @@ -153,7 +156,7 @@ extern std::string openvino_device_type; #include "core/providers/cann/cann_execution_provider_info.h" #endif -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) namespace onnxruntime { ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); ProviderInfo_CUDA& GetProviderInfo_CUDA(); @@ -170,14 +173,14 @@ extern onnxruntime::ArenaExtendStrategy arena_extend_strategy; } // namespace onnxruntime #endif -#ifdef USE_TENSORRT +#if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) namespace onnxruntime { ProviderInfo_TensorRT* TryGetProviderInfo_TensorRT(); ProviderInfo_TensorRT& GetProviderInfo_TensorRT(); } // namespace onnxruntime #endif -#ifdef USE_NV +#if defined(USE_NV) || defined(USE_NV_PROVIDER_INTERFACE) namespace onnxruntime { ProviderInfo_Nv* TryGetProviderInfo_Nv(); ProviderInfo_Nv& GetProviderInfo_Nv(); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py index 5f24867901570..01d51099ca577 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py @@ -12,7 +12,7 @@ import pytest from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix -max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", 64)) +max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", "64")) def dtype_to_suffix(dtype): diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py index 8a6713f6e03a1..aedcc0c5b71ce 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py @@ -13,7 +13,7 @@ import pytest from utils import dtype_to_suffix, matmul, softmax -max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", 64)) +max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", "64")) def multinormal_distribution(num_distribution, num_element_per_dist): diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py index b8c9c6f6a4ab6..4c7ed67f44c6b 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py @@ -12,7 +12,7 @@ import pytest from utils import get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix -max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", 64)) +max_batch_size = int(os.environ.get("KERNEL_EXPLORER_BATCHED_GEMM_MAX_BATCH_SIZE", "64")) def dtype_to_suffix(dtype): diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 8eea6b5c9fadc..85ac77be2af31 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -992,7 +992,7 @@ def compute_percentile(self): thresholds_dict[tensor] = (thresholds_dict[tensor][0], max_value) thresholds_dict[tensor] = (*thresholds_dict[tensor], *hist[:2]) # Plot histogram for debug only - if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"): + if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"): apply_plot(hist, hist_edges) return thresholds_dict @@ -1013,7 +1013,7 @@ def compute_entropy(self): thresholds_dict[tensor] = (*optimal_threshold, *histogram[:2]) # Plot histogram for debug only - if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"): + if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"): apply_plot(histogram[0], histogram[1]) return thresholds_dict @@ -1075,7 +1075,7 @@ def compute_distribution(self): ) # Plot histogram for debug only - if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"): + if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"): apply_plot(hist, hist_edges) return thresholds_dict diff --git a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py index 2e8ee11e2f864..66bac6321f41a 100644 --- a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py @@ -138,7 +138,7 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): for attr in node.attribute if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS ] - if len(graph_attrs): + if graph_attrs: kwargs = {} for attr in node.attribute: if attr.type == onnx.AttributeProto.GRAPH: diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index 449c5988438f9..b1d58b713eea8 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -1181,7 +1181,7 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): for attr in node.attribute if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS ] - if len(graph_attrs): + if graph_attrs: kwargs = {} for attr in node.attribute: if attr.type == onnx.AttributeProto.GRAPH: diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 4e3ef5febf382..5eba6c386b943 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -342,7 +342,7 @@ def __replace_gemm_with_matmul(graph_path): graph = graph_path[-1] for node in graph.node: graph_attrs = [attr for attr in node.attribute if attr.type == 5 or attr.type == 10] - if len(graph_attrs): + if graph_attrs: kwargs = {} for attr in node.attribute: if attr.type == 5: diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 424f9b7e180a3..b0a78281041d0 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -85,7 +85,7 @@ def __init__( self.tensor_names.update({ot.name: 1 for ot in model.graph.output}) self.tensor_names.update({it.name: 1 for it in model.graph.input}) for node in self.model.model.graph.node: - self.tensor_names.update({output_name: 1 for output_name in node.output}) + self.tensor_names.update(dict.fromkeys(node.output, 1)) if self.mode not in QuantizationMode: raise ValueError(f"unsupported quantization mode {self.mode}") diff --git a/onnxruntime/python/tools/quantization/operators/gemm.py b/onnxruntime/python/tools/quantization/operators/gemm.py index 6b8a389824b2d..78731b35c9f4a 100644 --- a/onnxruntime/python/tools/quantization/operators/gemm.py +++ b/onnxruntime/python/tools/quantization/operators/gemm.py @@ -19,7 +19,7 @@ def is_B_transposed(gemm_node): # noqa: N802 transB_attribute = [attr for attr in gemm_node.attribute if attr.name == "transB"] # noqa: N806 - if len(transB_attribute): + if transB_attribute: return onnx.helper.get_attribute_value(transB_attribute[0]) > 0 return False @@ -27,7 +27,7 @@ def is_B_transposed(gemm_node): # noqa: N802 def get_beta(gemm_node): beta_attribute = [attr for attr in gemm_node.attribute if attr.name == "beta"] - if len(beta_attribute): + if beta_attribute: return onnx.helper.get_attribute_value(beta_attribute[0]) return 1.0 @@ -35,7 +35,7 @@ def get_beta(gemm_node): def set_default_beta(gemm_node): beta_attribute = [attr for attr in gemm_node.attribute if attr.name == "beta"] - if len(beta_attribute): + if beta_attribute: beta_attribute[0].f = 1.0 return 1.0 diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index dd2921e8b69c2..48cd1c52be2e2 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -869,7 +869,7 @@ def default(self, obj): file.write(buf) # Deserialize data (for validation) - if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"): + if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"): cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0) dict_len = cal_table.DictLength() for i in range(dict_len): diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 3dd2c2ef945ec..d159bbefb41c7 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -279,7 +279,7 @@ def output_summary(results, csv_filename, args): headers = {k: v for k, v in result.items() if k in header_names} if not row: row.update(headers) - row.update({k: "" for k in data_names}) + row.update(dict.fromkeys(data_names, "")) else: for k in header_names: assert row[k] == headers[k] diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py index 0f0c12b0e0200..2b7fbffa842f7 100644 --- a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -126,8 +126,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): [1, 0, 0, 0, 0], ) if k_nodes is None: - logger.debug("fuse_conformer_attention: failed to match k path") - return + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0], + ) + if k_nodes is None: + logger.debug("fuse_conformer_attention: failed to match k path") + return else: concat_k = k_nodes[1] concat_parent = self.model.get_parent(concat_k, 0, None) @@ -188,7 +194,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed") return - self.increase_counter(new_node.op_type) self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 5870a031086ee..2b19ae5029ecc 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -39,6 +39,7 @@ def ort_type_to_numpy_type(ort_type: str): "tensor(float)": numpy.float32, "tensor(float16)": numpy.float16, "tensor(bool)": bool, + "tensor(uint8)": numpy.uint8, } if ort_type not in ort_type_to_numpy_type_map: raise ValueError(f"{ort_type} not found in map") @@ -53,6 +54,7 @@ def ort_type_to_torch_type(ort_type: str): "tensor(float)": torch.float32, "tensor(float16)": torch.float16, "tensor(bool)": torch.bool, + "tensor(uint8)": torch.uint8, } if ort_type not in ort_type_to_torch_type_map: raise ValueError(f"{ort_type} not found in map") @@ -68,6 +70,7 @@ def numpy_type_to_torch_type(numpy_type: numpy.dtype): numpy.float32: torch.float32, numpy.float16: torch.float16, bool: torch.bool, + numpy.uint8: torch.uint8, } if numpy_type not in numpy_type_to_torch_type_map: raise ValueError(f"{numpy_type} not found in map") @@ -82,6 +85,7 @@ def torch_type_to_numpy_type(torch_type: torch.dtype): torch.float32: numpy.float32, torch.float16: numpy.float16, torch.bool: bool, + torch.uint8: numpy.uint8, } if torch_type not in torch_type_to_numpy_type_map: raise ValueError(f"{torch_type} not found in map") diff --git a/onnxruntime/python/tools/transformers/models/bert/eval_squad.py b/onnxruntime/python/tools/transformers/models/bert/eval_squad.py index 680b3455ade2d..5d0c4558d8d9c 100644 --- a/onnxruntime/python/tools/transformers/models/bert/eval_squad.py +++ b/onnxruntime/python/tools/transformers/models/bert/eval_squad.py @@ -176,7 +176,7 @@ def output_summary(results: list[dict[str, Any]], csv_filename: str, metric_name # Metric value for given pair of batch_size and sequence_length. # Assume that (onnx_path, batch_size and sequence_length) are unique so keep first occurrence only. values = {} - values.update({k: "" for k in key_names}) + values.update(dict.fromkeys(key_names, "")) for result in results: if result["onnx_path"] == model and result[metric_name]: diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index f44aecc51ca34..7794f4d9fefee 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -25,7 +25,7 @@ # to patch transformers before exporting for transformers >= 4.45 from models.torch_export_patches import bypass_export_some_errors -from models.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes, replace_dynamic_shapes +from models.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes from onnx_model import OnnxModel from optimizer import optimize_model from packaging import version @@ -165,62 +165,18 @@ def run_dynamo_export( llama, args=model_args, dynamic_axes=dynamic_axes, prefix_mapping={"present": "past_key_values"} ) - if version.Version(torch.__version__) < version.Version("2.7"): - # This section is only needed for torch==2.6. The workaround implemented here - # to fix bugs is not necessary with torch>=2.7. - # - strings are not allowed with torch 2.6, so we replace them by DYNAMIC - # - TypePromotion was fixed in torch==2.7 - from onnxscript import opset18 as op - - dynamic_shapes = replace_dynamic_shapes( - dynamic_shapes, - dict(batch_size=torch.export.Dim("batch_size")), - default_value=torch.export.Dim.DYNAMIC, + with bypass_export_some_errors(patch_transformers=True): + torch.onnx.export( + llama, + (), + temp_path, + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + dynamo=True, + verbose=args.verbose, + optimize=True, ) - # TypePromotion cannot fix a type issue after the conversion. - # We insert an additional CastLike when the exporter - def custom_aten_ge(self, other): - if isinstance(other, (int, float)): - return op.GreaterOrEqual(self, op.CastLike(other, self)) - return op.GreaterOrEqual(self, other) - - with bypass_export_some_errors(patch_transformers=True): - # ONNX pass TypePromotion crashes for torch 2.6. - # It can be bypassed by exporting first into an exported program. - # We then need to apply run_decompositions() before onnx conversion starts. - ep = torch.export.export( - llama, - (), - kwargs=model_kwargs, - dynamic_shapes=dynamic_shapes, - strict=False, - ) - ep = ep.run_decompositions() - torch.onnx.export( - ep, - (), - temp_path, - kwargs=model_kwargs, - dynamic_shapes=dynamic_shapes, - dynamo=True, - verbose=args.verbose, - optimize=True, - custom_translation_table={torch.ops.aten.ge.Scalar: custom_aten_ge}, - ) - else: - with bypass_export_some_errors(patch_transformers=True): - torch.onnx.export( - llama, - (), - temp_path, - kwargs=model_kwargs, - dynamic_shapes=dynamic_shapes, - dynamo=True, - verbose=args.verbose, - optimize=True, - ) - # Check decoder_with_past_model.onnx and save all external data to one file onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py index 3b53f60758b27..1fe1afb46149c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py +++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py @@ -16,9 +16,9 @@ def init_dist(): dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank) elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: - int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) - rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) - world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) + int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0")) + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0")) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1")) dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank) else: diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 9b167aa177fb7..383101c8a3b72 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -13,7 +13,6 @@ import numpy as np import packaging.version as pv import torch -import transformers from benchmark_helper import setup_logger from dist_settings import get_rank, get_size from llama_inputs import ( @@ -27,6 +26,7 @@ from llama_torch import setup_torch_model from models.torch_export_patches.cache_helper import make_dynamic_cache from transformers import AutoConfig +from transformers import __version__ as transformers_version from transformers.cache_utils import DynamicCache import onnxruntime as ort @@ -105,7 +105,7 @@ def verify_parity( pytorch_model: None | torch.nn.Module = None, config: None | AutoConfig = None, ): - # If it's running in a machine which GPU memory < 36GB, it should unload the llama in GPU in time and free the GPU memory for ORT. + # If it's running in a machine where GPU memory < 36GB, it should unload the model in GPU in time and free the GPU memory for ORT. py_model = pytorch_model if py_model is None: config, py_model = setup_torch_model( @@ -118,18 +118,19 @@ def verify_parity( inputs = get_inputs(args, config) - if "past_key_values" in inputs and pv.Version(transformers.__version__) >= pv.Version("4.45"): + if "past_key_values" in inputs and pv.Version(transformers_version) >= pv.Version("4.45"): # Using DynamicCache inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"]) # Run inference with PyTorch + inputs_after_deepcopy = torch_deepcopy(inputs) if args.execution_provider != "cpu": torch.cuda.synchronize() start_time = time.time() - # If there is a cache in the inputs, we need to make a copy as the model modify them inplace. + # If there is a cache in the inputs, we need to make a copy as the model modifies them inplace. # DynamicCache inherits from torch.nn.Module in some version of transformers. # We need to make the copy manually. - pt_outputs = py_model(**torch_deepcopy(inputs)).logits.detach().cpu().numpy() + pt_outputs = py_model(**inputs_after_deepcopy).logits.detach().cpu().numpy() if args.execution_provider != "cpu": torch.cuda.synchronize() end_time = time.time() diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index b4c3278feaf79..6bd698f8b75b4 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,8 +1,8 @@ onnxscript>=0.2.3 optimum>=1.14.1 -optree # this is still needed when pytorch==2.6 is used +optree transformers==4.48.0 -torch>=2.2.0 +torch>=2.7.0 onnx==1.17.0 datasets>=2.8.0 protobuf==3.20.2 diff --git a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py index c7e0e31765a4f..21848deaf99fe 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py +++ b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py @@ -692,10 +692,10 @@ def output_summary(results, csv_filename, data_field="average_latency_ms"): row = {} sum_latency = {} - sum_latency.update({k: 0 for k in data_names}) + sum_latency.update(dict.fromkeys(data_names, 0)) count_latency = {} - count_latency.update({k: 0 for k in data_names}) + count_latency.update(dict.fromkeys(data_names, 0)) for result in results: if result["description"] == description and result[data_field]: diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 99a2d9379598d..4f57f3446f5f6 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -11,7 +11,6 @@ import time from pathlib import Path -import __init__ # noqa: F401. Walk-around to run this script directly import coloredlogs # import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 24897756b2d7a..eb4d7242f72fc 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -22,7 +22,6 @@ import tempfile from pathlib import Path -import __init__ # noqa: F401. Walk-around to run this script directly import coloredlogs import onnx from fusion_options import FusionOptions diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt index 5bdd422a11750..c212f7ea825b5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt @@ -1,6 +1,6 @@ huggingface_hub==0.25.2 diffusers==0.28.0 -transformers==4.41.2 +transformers==4.50.0 numpy>=1.24.1 accelerate onnx==1.17.0 diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py index a9ce1d6144f45..aaf2f1db90ee3 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py @@ -4,7 +4,8 @@ import numpy as np import packaging.version as pv import torch -import transformers +from transformers import __version__ as transformers_version +from transformers.cache_utils import DynamicCache, EncoderDecoderCache from .onnx_export_errors import ( bypass_export_some_errors, @@ -422,43 +423,41 @@ def string_type( raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}") -if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): +if pv.Version(transformers_version) > pv.Version("4.49.99999"): def make_dynamic_cache( key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], - ) -> transformers.cache_utils.DynamicCache: + ) -> DynamicCache: """ - Creates an instance of :class:`transformers.cache_utils.DynamicCache`. + Creates an instance of :class:`DynamicCache`. This version is valid for ``transformers >= 4.50``. :param key_value_pairs: list of pairs of (key, values) - :return: :class:`transformers.cache_utils.DynamicCache` + :return: :class:`DynamicCache` """ - return transformers.cache_utils.DynamicCache(key_value_pairs) + return DynamicCache(key_value_pairs) else: def make_dynamic_cache( key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], - ) -> transformers.cache_utils.DynamicCache: + ) -> DynamicCache: """ - Creates an instance of :class:`transformers.cache_utils.DynamicCache`. + Creates an instance of :class:`DynamicCache`. This version is valid for ``transformers < 4.50``. :param key_value_pairs: list of pairs of (key, values) - :return: :class:`transformers.cache_utils.DynamicCache` + :return: :class:`DynamicCache` """ - cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) + cache = DynamicCache(len(key_value_pairs)) for i, (key, value) in enumerate(key_value_pairs): cache.update(key, value, i) return cache def make_encoder_decoder_cache( - self_attention_cache: transformers.cache_utils.DynamicCache, - cross_attention_cache: transformers.cache_utils.DynamicCache, -) -> transformers.cache_utils.EncoderDecoderCache: + self_attention_cache: DynamicCache, + cross_attention_cache: DynamicCache, +) -> EncoderDecoderCache: "Creates an EncoderDecoderCache." - return transformers.cache_utils.EncoderDecoderCache( - self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache - ) + return EncoderDecoderCache(self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py index 0cbe1e58a9e02..f5631cccecb28 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py @@ -1,12 +1,12 @@ import packaging.version as pv import torch -import transformers -import transformers.cache_utils +from transformers import __version__ as transformers_version +from transformers.cache_utils import DynamicCache, EncoderDecoderCache def is_cache_dynamic_registered(fast: bool = False) -> bool: """ - Tells class :class:`transformers.cache_utils.DynamicCache` can be + Tells class :class:`DynamicCache` can be serialized and deserialized. Only then, :func:`torch.export.export` can export a model. @@ -14,7 +14,7 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool: :return: result """ if fast: - return transformers.cache_utils.DynamicCache in torch.utils._pytree.SUPPORTED_NODES + return DynamicCache in torch.utils._pytree.SUPPORTED_NODES bsize, nheads, slen, dim = 2, 4, 3, 7 cache = make_dynamic_cache( [ @@ -30,45 +30,43 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool: return len(cache2.key_cache) == len(cache.value_cache) -if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): +if pv.Version(transformers_version) > pv.Version("4.49.99999"): def make_dynamic_cache( key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], - ) -> transformers.cache_utils.DynamicCache: + ) -> DynamicCache: """ - Creates an instance of :class:`transformers.cache_utils.DynamicCache`. + Creates an instance of :class:`DynamicCache`. This version is valid for ``transformers >= 4.50``. :param key_value_pairs: list of pairs of (key, values) - :return: :class:`transformers.cache_utils.DynamicCache` + :return: :class:`DynamicCache` """ - return transformers.cache_utils.DynamicCache(key_value_pairs) + return DynamicCache(key_value_pairs) else: def make_dynamic_cache( key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], - ) -> transformers.cache_utils.DynamicCache: + ) -> DynamicCache: """ - Creates an instance of :class:`transformers.cache_utils.DynamicCache`. + Creates an instance of :class:`DynamicCache`. This version is valid for ``transformers < 4.50``. :param key_value_pairs: list of pairs of (key, values) - :return: :class:`transformers.cache_utils.DynamicCache` + :return: :class:`DynamicCache` """ - cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) + cache = DynamicCache(len(key_value_pairs)) for i, (key, value) in enumerate(key_value_pairs): cache.update(key, value, i) return cache def make_encoder_decoder_cache( - self_attention_cache: transformers.cache_utils.DynamicCache, - cross_attention_cache: transformers.cache_utils.DynamicCache, -) -> transformers.cache_utils.EncoderDecoderCache: + self_attention_cache: DynamicCache, + cross_attention_cache: DynamicCache, +) -> EncoderDecoderCache: """ Creates an EncoderDecoderCache. """ - return transformers.cache_utils.EncoderDecoderCache( - self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache - ) + return EncoderDecoderCache(self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py index 5dd3b38a8232a..cd49dd2d1b0e6 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py @@ -71,7 +71,6 @@ def _register_cache_serialization(verbose: int = 0) -> dict[str, bool]: # Cache serialization: to be moved into appropriate packages import packaging.version as pv import torch - import transformers try: from transformers.cache_utils import DynamicCache @@ -110,7 +109,7 @@ def _register_cache_serialization(verbose: int = 0) -> dict[str, bool]: # torch.fx._pytree.register_pytree_flatten_spec( # DynamicCache, _flatten_dynamic_cache_for_fx) # so we remove it anyway - if DynamicCache in torch.fx._pytree.SUPPORTED_NODES and pv.Version(transformers.__version__) >= pv.Version("2.7"): + if DynamicCache in torch.fx._pytree.SUPPORTED_NODES and pv.Version(torch.__version__) >= pv.Version("2.7"): if verbose: print("[_register_cache_serialization] DynamicCache is unregistered first.") _unregister(DynamicCache) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py index d109dd3059480..333e6ef5bee3b 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py @@ -1,7 +1,7 @@ from typing import Any import torch -import transformers +from transformers.cache_utils import DynamicCache, MambaCache ############ # MambaCache @@ -25,9 +25,9 @@ # dtype=dtype, # ) def flatten_mamba_cache( - mamba_cache: transformers.cache_utils.MambaCache, + mamba_cache: MambaCache, ) -> tuple[list[Any], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + """Serializes a :class:`MambaCache` with python objects.""" flat = [ (k, getattr(mamba_cache, k)) for k in [ @@ -47,8 +47,8 @@ def unflatten_mamba_cache( values: list[Any], context: torch.utils._pytree.Context, output_type=None, -) -> transformers.cache_utils.MambaCache: - """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" +) -> MambaCache: + """Restores a :class:`MambaCache` from python objects.""" conv_states, ssm_states = values class _config: @@ -64,8 +64,6 @@ def __init__(self): self.conv_kernel = conv_states.shape[3] self.num_hidden_layers = conv_states.shape[0] - from transformers.cache_utils import MambaCache - cache = MambaCache( _config(), max_batch_size=1, @@ -84,9 +82,7 @@ def flatten_with_keys_mamba_cache( list[tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context, ]: - """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" - import torch - + """Serializes a :class:`MambaCache` with python objects.""" values, context = flatten_mamba_cache(d) return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values, strict=False)], context @@ -97,9 +93,9 @@ def flatten_with_keys_mamba_cache( def flatten_dynamic_cache( - dynamic_cache: transformers.cache_utils.DynamicCache, + dynamic_cache: DynamicCache, ) -> tuple[list[Any], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + """Serializes a :class:`DynamicCache` with python objects.""" flat = [(k, getattr(dynamic_cache, k)) for k in ["key_cache", "value_cache"] if hasattr(dynamic_cache, k)] return [f[1] for f in flat], [f[0] for f in flat] @@ -110,9 +106,7 @@ def flatten_with_keys_dynamic_cache( list[tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context, ]: - """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" - import torch - + """Serializes a :class:`DynamicCache` with python objects.""" values, context = flatten_dynamic_cache(d) return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values, strict=False)], context @@ -121,10 +115,8 @@ def unflatten_dynamic_cache( values: list[Any], context: torch.utils._pytree.Context, output_type=None, -) -> transformers.cache_utils.DynamicCache: - """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects.""" - from transformers.cache_utils import DynamicCache - +) -> DynamicCache: + """Restores a :class:`DynamicCache` from python objects.""" cache = DynamicCache() values = dict(zip(context, values, strict=False)) for k, v in values.items(): diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py index ded05b8c37be5..41ca2f463fdea 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py @@ -2,7 +2,7 @@ from typing import Any import torch -import transformers +from transformers.cache_utils import DynamicCache from . import string_type from .cache_helper import make_dynamic_cache @@ -22,7 +22,7 @@ def _process_cache(k: str, v): def _make_shape(subset: dict, cls: type, value: Any) -> Any: - if cls is transformers.cache_utils.DynamicCache: + if cls is DynamicCache: assert subset, "DynamicCache cannot be empty" values = set(map(str, subset.values())) assert len(values) == 1, ( diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py index 828a883b7ab12..aa762ef5f3b3f 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py @@ -4,9 +4,9 @@ from typing import Any import torch -import transformers -import transformers.modeling_attn_mask_utils from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation.utils import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter def _patch_make_causal_mask( @@ -50,11 +50,11 @@ def _patch_make_causal_mask( class patched_AttentionMaskConverter: """ Patches - ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. + ``AttentionMaskConverter._make_causal_mask``. """ _PATCHES_ = ["_make_causal_mask"] - _PATCHED_CLASS_ = transformers.modeling_attn_mask_utils.AttentionMaskConverter + _PATCHED_CLASS_ = AttentionMaskConverter @staticmethod def _make_causal_mask( @@ -73,11 +73,11 @@ def _make_causal_mask( class patched_AttentionMaskConverter: """ Patches - ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. + ``AttentionMaskConverter._make_causal_mask``. """ _PATCHES_ = ["_make_causal_mask"] - _PATCHED_CLASS_ = transformers.modeling_attn_mask_utils.AttentionMaskConverter + _PATCHED_CLASS_ = AttentionMaskConverter @staticmethod def _make_causal_mask( @@ -99,7 +99,7 @@ class patched_DynamicCache: """ _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"] - _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache + _PATCHED_CLASS_ = DynamicCache def get_seq_length(self, layer_idx: int | None = 0) -> int: """Returns the sequence length of the cached states. @@ -217,7 +217,7 @@ class patched_GenerationMixin: "_cache_dependant_input_preparation_exporting", "prepare_inputs_for_generation", ] - _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin + _PATCHED_CLASS_ = GenerationMixin def _cache_dependant_input_preparation( self, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt index db2cd95324328..ed4355fcede1f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt @@ -1,2 +1,2 @@ -r requirements.txt -onnxruntime>=1.17.1 \ No newline at end of file +onnxruntime>=1.17.1 diff --git a/onnxruntime/test/autoep/library/example_plugin_ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep.cc index b88ad3e896ea8..2c82b9ace3c61 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep.cc @@ -42,9 +42,10 @@ struct ExampleEp : OrtEp, ApiPtrs { struct ExampleEpFactory : OrtEpFactory, ApiPtrs { ExampleEpFactory(const char* ep_name, ApiPtrs apis) : ApiPtrs(apis), ep_name_{ep_name} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; - GetDeviceInfoIfSupported = GetDeviceInfoIfSupportedImpl; + GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; } @@ -59,25 +60,56 @@ struct ExampleEpFactory : OrtEpFactory, ApiPtrs { return factory->vendor_.c_str(); } - static bool ORT_API_CALL GetDeviceInfoIfSupportedImpl(const OrtEpFactory* this_ptr, - const OrtHardwareDevice* device, - _Out_opt_ OrtKeyValuePairs** ep_metadata, - _Out_opt_ OrtKeyValuePairs** ep_options) { - const auto* factory = static_cast(this_ptr); - - if (factory->ort_api.HardwareDevice_Type(device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - // these can be returned as nullptr if you have nothing to add. - factory->ort_api.CreateKeyValuePairs(ep_metadata); - factory->ort_api.CreateKeyValuePairs(ep_options); - - // random example using made up values - factory->ort_api.AddKeyValuePair(*ep_metadata, "version", "0.1"); - factory->ort_api.AddKeyValuePair(*ep_options, "run_really_fast", "true"); + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); - return true; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + // C API + const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + // these can be returned as nullptr if you have nothing to add. + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_metadata); + factory->ort_api.CreateKeyValuePairs(&ep_options); + + // random example using made up values + factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1"); + factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true"); + + // OrtEpDevice copies ep_metadata and ep_options. + auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, + &ep_devices[num_ep_devices++]); + + factory->ort_api.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return status; + } + } + + // C++ API equivalent. Throws on error. + //{ + // Ort::ConstHardwareDevice device(devices[i]); + // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + // Ort::KeyValuePairs ep_metadata; + // Ort::KeyValuePairs ep_options; + // ep_metadata.Add("version", "0.1"); + // ep_options.Add("run_really_fast", "true"); + // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; + // ep_devices[num_ep_devices++] = ep_device.release(); + // } + //} } - return false; + return nullptr; } static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, @@ -88,6 +120,7 @@ struct ExampleEpFactory : OrtEpFactory, ApiPtrs { _In_ const OrtLogger* logger, _Out_ OrtEp** ep) { auto* factory = static_cast(this_ptr); + *ep = nullptr; if (num_devices != 1) { // we only registered for CPU and only expected to be selected for one CPU diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index b5d9c81f250c2..619f0a4bcda33 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -58,7 +58,7 @@ template & model_uri, const std::string& ep_to_select, std::optional library_path, - const OrtKeyValuePairs& provider_options, + const Ort::KeyValuePairs& ep_options, const std::vector& inputs, const char* output_name, const std::vector& expected_dims_y, @@ -75,13 +75,15 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod if (auto_select) { // manually specify EP to select for now - ASSERT_ORTSTATUS_OK(Ort::GetApi().AddSessionConfigEntry(session_options, "test.ep_to_select", - ep_to_select.c_str())); + session_options.AddConfigEntry("test.ep_to_select", ep_to_select.c_str()); + // add the provider options to the session options with the required prefix const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_to_select.c_str()); - for (const auto& [key, value] : provider_options.entries) { + std::vector keys, values; + ep_options.GetKeyValuePairs(keys, values); + for (size_t i = 0, end = keys.size(); i < end; ++i) { // add the default value with prefix - session_options.AddConfigEntry((option_prefix + key).c_str(), value.c_str()); + session_options.AddConfigEntry((option_prefix + keys[i]).c_str(), values[i]); } } else { std::vector devices; @@ -92,9 +94,17 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod DefaultDeviceSelection(ep_to_select, devices); } - ASSERT_ORTSTATUS_OK(Ort::GetApi().SessionOptionsAppendExecutionProvider_V2( - session_options, env, devices.data(), devices.size(), - provider_options.keys.data(), provider_options.values.data(), provider_options.entries.size())); + // C API. Test the C++ API because if it works the C API must also work. + // ASSERT_ORTSTATUS_OK(Ort::GetApi().SessionOptionsAppendExecutionProvider_V2( + // session_options, env, devices.data(), devices.size(), + // provider_options.keys.data(), provider_options.values.data(), provider_options.entries.size())); + std::vector ep_devices; + ep_devices.reserve(devices.size()); + for (const auto* device : devices) { + ep_devices.emplace_back(device); + } + + session_options.AppendExecutionProvider_V2(*ort_env, ep_devices, ep_options); } // if session creation passes, model loads fine @@ -115,7 +125,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod namespace { void RunBasicTest(const std::string& ep_name, std::optional library_path, - const OrtKeyValuePairs& provider_options = {}, + const Ort::KeyValuePairs& provider_options = Ort::KeyValuePairs{}, const std::function&)>& select_devices = nullptr) { const auto run_test = [&](bool auto_select) { std::vector> inputs(1); @@ -149,7 +159,7 @@ TEST(AutoEpSelection, CpuEP) { #if defined(USE_CUDA) TEST(AutoEpSelection, CudaEP) { - OrtKeyValuePairs provider_options; + Ort::KeyValuePairs provider_options; provider_options.Add("prefer_nhwc", "1"); RunBasicTest(kCudaExecutionProvider, "onnxruntime_providers_cuda", provider_options); } @@ -157,7 +167,7 @@ TEST(AutoEpSelection, CudaEP) { #if defined(USE_DML) TEST(AutoEpSelection, DmlEP) { - OrtKeyValuePairs provider_options; + Ort::KeyValuePairs provider_options; provider_options.Add("disable_metacommands", "true"); // checking options are passed through const auto select_devices = [&](std::vector& devices) { @@ -172,6 +182,7 @@ TEST(AutoEpSelection, DmlEP) { if (strcmp(c_api->EpDevice_EpName(ep_device), kDmlExecutionProvider) == 0) { const auto* device = c_api->EpDevice_Device(ep_device); const OrtKeyValuePairs* kvps = c_api->HardwareDevice_Metadata(device); + if (devices.empty()) { // add the first device devices.push_back(ep_device); @@ -179,13 +190,7 @@ TEST(AutoEpSelection, DmlEP) { // if this is available, 0 == best performance auto* perf_index = c_api->GetKeyValue(kvps, "HighPerformanceIndex"); if (perf_index && strcmp(perf_index, "0") == 0) { - devices.push_back(ep_device); - } else { - // let an NVIDIA device override the first device - if (strcmp(c_api->EpDevice_EpVendor(ep_device), "NVIDIA") == 0) { - devices.clear(); - devices[0] = ep_device; - } + devices[0] = ep_device; // replace as this is the higher performance device } } } @@ -204,16 +209,73 @@ TEST(AutoEpSelection, WebGpuEP) { } #endif -TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { +// tests for AutoEP selection related things in the API that aren't covered by the other tests. +TEST(AutoEpSelection, MiscApiTests) { + const OrtApi* c_api = &Ort::GetApi(); + + // nullptr and empty input to OrtKeyValuePairs + { + OrtKeyValuePairs* kvps = nullptr; + c_api->CreateKeyValuePairs(&kvps); + c_api->AddKeyValuePair(kvps, "key1", nullptr); // should be ignored + c_api->AddKeyValuePair(kvps, nullptr, "value1"); // should be ignored + c_api->RemoveKeyValuePair(kvps, nullptr); // should be ignored + + c_api->AddKeyValuePair(kvps, "", "value2"); // empty key should be ignored + ASSERT_EQ(c_api->GetKeyValue(kvps, ""), nullptr); + + c_api->AddKeyValuePair(kvps, "key2", ""); // empty value is allowed + ASSERT_EQ(c_api->GetKeyValue(kvps, "key2"), std::string("")); + + c_api->ReleaseKeyValuePairs(kvps); + } + + // construct KVP from std::unordered_map + { + std::unordered_map kvps; + kvps["key1"] = "value1"; + kvps["key2"] = "value2"; + Ort::KeyValuePairs ort_kvps(kvps); + ASSERT_EQ(ort_kvps.GetValue("key1"), std::string("value1")); + ASSERT_EQ(ort_kvps.GetValue("key2"), std::string("value2")); + } + + std::vector ep_devices = ort_env->GetEpDevices(); + + // explicit EP selection with Ort::KeyValuePairs for options + { + Ort::SessionOptions session_options; + Ort::KeyValuePairs ep_options; + ep_options.Add("option1", "true"); + session_options.AppendExecutionProvider_V2(*ort_env, {ep_devices[0]}, ep_options); + } + + // explicit EP selection with for options + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + ep_options["option1"] = "true"; + session_options.AppendExecutionProvider_V2(*ort_env, {ep_devices[0]}, ep_options); + } +} + +namespace { +struct ExamplePluginInfo { + const std::filesystem::path library_path = #if _WIN32 - std::filesystem::path library_path = "example_plugin_ep.dll"; + "example_plugin_ep.dll"; #else - std::filesystem::path library_path = "libexample_plugin_ep.so"; + "libexample_plugin_ep.so"; #endif - const std::string registration_name = "example_ep"; +}; - Ort::SessionOptions session_options; +static const ExamplePluginInfo example_plugin_info; +} // namespace + +TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { + const std::filesystem::path& library_path = example_plugin_info.library_path; + const std::string& registration_name = example_plugin_info.registration_name; OrtEnv* c_api_env = *ort_env; const OrtApi* c_api = &Ort::GetApi(); @@ -238,6 +300,48 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(c_api_env, registration_name.c_str())); } + +TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { + const std::filesystem::path& library_path = example_plugin_info.library_path; + const std::string& registration_name = example_plugin_info.registration_name; + + // this should load the library and create OrtEpDevice + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + std::vector ep_devices = ort_env->GetEpDevices(); + + // should be one device for the example EP + auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), + [®istration_name](Ort::ConstEpDevice& device) { + // the example uses the registration name for the EP name + // but that is not a requirement and the two can differ. + return device.EpName() == registration_name; + }); + ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by the test library."; + + // test all the C++ getters. expected values are from \onnxruntime\test\autoep\library\example_plugin_ep.cc + ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso"); + + auto metadata = test_ep_device->EpMetadata(); + ASSERT_STREQ(metadata.GetValue("version"), "0.1"); + + auto options = test_ep_device->EpOptions(); + ASSERT_STREQ(options.GetValue("run_really_fast"), "true"); + + // the CPU device info will vary by machine so check for the lowest common denominator values + Ort::ConstHardwareDevice device = test_ep_device->Device(); + ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU); + ASSERT_GE(device.VendorId(), 0); + ASSERT_GE(device.DeviceId(), 0); + ASSERT_NE(device.Vendor(), nullptr); + Ort::ConstKeyValuePairs device_metadata = device.Metadata(); + std::unordered_map metadata_entries = device_metadata.GetKeyValuePairs(); + ASSERT_GT(metadata_entries.size(), 0); // should have at least SPDRP_HARDWAREID on Windows + + // and this should unload it without throwing + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc new file mode 100644 index 0000000000000..b29fc5181eb46 --- /dev/null +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -0,0 +1,293 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef ORT_MINIMAL_BUILD +#if (defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) + +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/common/narrow.h" +#include "core/common/span_utils.h" +#include "core/framework/tensor.h" +#include "core/mlas/inc/mlas_qnbit.h" +#include "core/mlas/inc/mlas_q4.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/inference_session.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/ort_env.h" +#include "core/util/qmath.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { + +namespace test { + +namespace { + +constexpr int QBits = 8; + +struct TestOptions8Bits { + int64_t M{1}; + int64_t N{1}; + int64_t K{1}; + int64_t block_size{32}; + int64_t accuracy_level{0}; + + bool has_zero_point{false}; + bool has_g_idx{false}; + bool has_bias{false}; + + std::optional output_abs_error{}; + std::optional output_rel_error{}; +}; + +[[maybe_unused]] std::ostream& operator<<(std::ostream& os, const TestOptions8Bits& opts) { + return os << "M:" << opts.M << ", N:" << opts.N << ", K:" << opts.K + << ", block_size:" << opts.block_size + << ", accuracy_level:" << opts.accuracy_level + << ", has_zero_point:" << opts.has_zero_point + << ", has_g_idx:" << opts.has_g_idx + << ", has_bias:" << opts.has_bias; +} + +template +void RunTest8Bits(const TestOptions8Bits& opts) { + SCOPED_TRACE(opts); + + const int64_t M = opts.M, + K = opts.K, + N = opts.N; + + RandomValueGenerator random{1234}; + std::vector input0_fp32_vals(random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f)); + std::vector input1_fp32_vals(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); + + int q_rows, q_cols; + MlasBlockwiseQuantizedShape(static_cast(opts.block_size), /* columnwise */ true, + static_cast(K), static_cast(N), + q_rows, q_cols); + + size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes(static_cast(opts.block_size), /* columnwise */ true, + static_cast(K), static_cast(N), + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + std::vector input1_vals(q_data_size_in_bytes); + std::vector scales(q_scale_size); + std::vector zp(q_zp_size_in_bytes); + + auto& ortenv = **ort_env.get(); + onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); + + MlasQuantizeBlockwise( + input1_vals.data(), + scales.data(), + opts.has_zero_point ? zp.data() : nullptr, + input1_fp32_vals.data(), + static_cast(opts.block_size), + true, + static_cast(K), + static_cast(N), + static_cast(N), + tp); + + // Note that raw_vals is NxK after dequant + MlasDequantizeBlockwise( + input1_fp32_vals.data(), + input1_vals.data(), + scales.data(), + opts.has_zero_point ? zp.data() : nullptr, + static_cast(opts.block_size), + true, + static_cast(K), + static_cast(N), + tp); + + const std::vector bias_shape = {N}; + const auto bias = [&]() -> std::optional> { + if (opts.has_bias) { + return random.Uniform(bias_shape, 1.0f, 5.0f); + } + return std::nullopt; + }(); + + std::vector expected_vals(M * N); + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + sum += input0_fp32_vals[m * K + k] * input1_fp32_vals[n * K + k]; + } + expected_vals[m * N + n] = sum + (bias.has_value() ? (*bias)[n] : 0.0f); + } + } + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", opts.block_size); + test.AddAttribute("bits", QBits); + test.AddAttribute("accuracy_level", opts.accuracy_level); + if constexpr (std::is_same::value) { + test.AddInput("A", {M, K}, input0_fp32_vals, false); + } else { + test.AddInput("A", {M, K}, FloatsToMLFloat16s(input0_fp32_vals), false); + } + + test.AddInput("B", {q_cols, q_rows}, input1_vals, true); + + if constexpr (std::is_same::value) { + test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); + } else { + test.AddInput("scales", {static_cast(q_scale_size)}, FloatsToMLFloat16s(scales), true); + } + + if (opts.has_zero_point) { + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOptionalInputEdge(); + + if (bias.has_value()) { + if constexpr (std::is_same::value) { + test.AddInput("bias", bias_shape, *bias, true); + } else { + test.AddInput("bias", bias_shape, FloatsToMLFloat16s(*bias), true); + } + } else { + test.AddOptionalInputEdge(); + } + + if constexpr (std::is_same::value) { + test.AddOutput("Y", {M, N}, expected_vals); + } else { + test.AddOutput("Y", {M, N}, FloatsToMLFloat16s(expected_vals)); + } + + if (opts.output_abs_error.has_value()) { + test.SetOutputAbsErr("Y", *opts.output_abs_error); + } + + if (opts.output_rel_error.has_value()) { + test.SetOutputRelErr("Y", *opts.output_rel_error); + } + + std::vector> execution_providers; +#ifdef USE_CUDA + execution_providers.emplace_back(DefaultCudaExecutionProvider()); + test.ConfigEps(std::move(execution_providers)); + test.RunWithConfig(); + execution_providers.clear(); +#else + if constexpr (std::is_same::value) { + if (MlasIsQNBitGemmAvailable(8, 32, SQNBIT_CompInt8)) { + execution_providers.emplace_back(DefaultCpuExecutionProvider()); + test.ConfigEps(std::move(execution_providers)); + test.RunWithConfig(); + } + } +#endif +} + +template +void TestMatMul8BitsTyped() { + TestOptions8Bits base_opts{}; + base_opts.M = M, base_opts.N = N, base_opts.K = K; + base_opts.block_size = block_size; + base_opts.accuracy_level = accuracy_level; + + if (base_opts.accuracy_level == 4) { + base_opts.output_abs_error = 0.1f; + base_opts.output_rel_error = 0.02f; + } else if constexpr (std::is_same::value) { + base_opts.output_abs_error = 0.055f; + base_opts.output_rel_error = 0.02f; + } + + { + TestOptions8Bits opts = base_opts; + RunTest8Bits(opts); + } + + { + TestOptions8Bits opts = base_opts; + opts.has_zero_point = true; + RunTest8Bits(opts); + } + +// CUDA does not support bias for MatMulNBits +#if not defined(USE_CUDA) + { + TestOptions8Bits opts = base_opts; + opts.has_bias = true; + RunTest8Bits(opts); + } + + { + TestOptions8Bits opts = base_opts; + opts.has_zero_point = true; + opts.has_bias = true; + RunTest8Bits(opts); + } +#endif +} +} // namespace + +TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float) { + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); +} + +#ifdef USE_CUDA +TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float16) { + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); +} +#endif + +} // namespace test +} // namespace onnxruntime + +#endif +#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp index 15be5e140db6f..8ad3b59ec9bbc 100644 --- a/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp @@ -135,6 +135,7 @@ static void QNBitGemmArgs(benchmark::internal::Benchmark* b) { } BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); // This test gets benchmark arguments from environment variables. diff --git a/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp new file mode 100644 index 0000000000000..fad804f3ce305 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp @@ -0,0 +1,521 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sq8bitgemm_neon.cpp + +Abstract: + + Tests for MatMul8Bits kernels on x86 CPU with input A type T1 fp32. + +--*/ + +#include +#include + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/inc/mlas_q4.h" +#include "core/mlas/lib/qnbitgemm.h" +#include "mlas_qnbit.h" + +class MlasSQ8BitPrepackTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution distrib_u8_; + std::uniform_real_distribution distrib_f32_; + MatrixGuardBuffer inputB_, inputZp_, refB_, packedBuffer_; + MatrixGuardBuffer inputScale_, refScale_; + MatrixGuardBuffer inputBlkSum_, refBlkSum_; + + template + void PrepackB(const uint8_t* src, uint8_t* dst) { + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + size_t n = 0; + for (; n + 4 <= N; n += 4) { + size_t k = 0; + for (; k + SubBlkLen <= ldb; k += SubBlkLen) { + for (size_t i = 0; i < 4; ++i) { + std::copy(src + (n + i) * ldb + k, src + (n + i) * ldb + k + SubBlkLen, dst + n * ldb + 4 * k + i * SubBlkLen); + } + } + + for (size_t kk = 0; kk + k + BlkLen <= ldb; kk += BlkLen) { + for (size_t i = 0; i < 4; ++i) { + std::copy(src + (n + i) * ldb + k + kk, src + (n + i) * ldb + k + kk + BlkLen, dst + n * ldb + 4 * k + 4 * kk + i * BlkLen); + } + } + } + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations" +#endif + for (; n < N; ++n) { + std::copy(src + n * ldb, src + n * ldb + ldb, dst + n * ldb); + } +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + } + + template + void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t BlkPerSubBlk = SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1; + + size_t n = 0; + for (; n + 4 <= N; n += 4) { + size_t k = 0; + for (; k + BlkPerSubBlk <= BlkCount; k += BlkPerSubBlk) { + for (size_t i = 0; i < 4; ++i) { + for (size_t j = 0; j < BlkPerSubBlk; ++j) { + auto srcOffset = (n + i) * BlkCount + k + j; + auto scaleDstOffset = n * BlkCount + 4 * k + i * BlkPerSubBlk + j; + auto sumDstOffset = (((n + i) / 16) * BlkCount + k + j) * 16 + (n + i) % 16; + + auto vSum = -scale[srcOffset] * (zp ? static_cast(zp[srcOffset]) : 128.f); + + packedScale[scaleDstOffset] = scale[srcOffset]; + blkSum[sumDstOffset] = vSum; + } + } + } + for (size_t kk = 0; k + kk < BlkCount; ++kk) { + for (size_t i = 0; i < 4; ++i) { + auto srcOffset = (n + i) * BlkCount + k + kk; + auto scaleDstOffset = n * BlkCount + 4 * k + 4 * kk + i; + auto sumDstOffset = (((n + i) / 16) * BlkCount + k + kk) * 16 + (n + i) % 16; + + auto vSum = -scale[srcOffset] * (zp ? static_cast(zp[srcOffset]) : 128.f); + + packedScale[scaleDstOffset] = scale[srcOffset]; + blkSum[sumDstOffset] = vSum; + } + } + } + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations" +#endif + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + auto srcOffset = n * BlkCount + k; + auto scaleDstOffset = n * BlkCount + k; + auto sumDstOffset = (((n) / 16) * BlkCount + k) * 16 + (n) % 16; + + auto vSum = -scale[srcOffset] * (zp ? static_cast(zp[srcOffset]) : 128.f); + + packedScale[scaleDstOffset] = scale[srcOffset]; + blkSum[sumDstOffset] = vSum; + } + } +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + } + + template + void CheckB(const uint8_t* packedB, const uint8_t* refB) { + size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + size_t n = 0, N4 = N & (~3), ldbSub = ldb & (~(SubBlkLen - 1)); + for (; n < N4; ++n) { + size_t k = 0; + for (; k < ldbSub && k < K; ++k) { + size_t idx = (n & (~3)) * ldb + (k & (~(SubBlkLen - 1))) * 4 + (n & 3) * SubBlkLen + (k & (SubBlkLen - 1)); + ASSERT_EQ(packedB[idx], refB[idx]) + << " n " << n << " k " << k; + } + for (; k < K; ++k) { + size_t idx = (n & (~3)) * ldb + (k & (~(BlkLen - 1))) * 4 + (n & 3) * BlkLen + (k & (BlkLen - 1)); + ASSERT_EQ(packedB[idx], refB[idx]) + << " n " << n << " k " << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + ASSERT_EQ(packedB[n * ldb + k], refB[n * ldb + k]) + << " n " << n << " k " << k; + } + } + } + + template + void CheckScale(const float* packedScale, const float* refScale) { + size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t BlkPerSubBlk = SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1; + size_t n = 0, N4 = N & (~3), BlkCountSub = BlkCount & (~(BlkPerSubBlk - 1)); + + for (; n < N4; ++n) { + size_t k = 0; + for (; k < BlkCountSub; ++k) { + size_t idx = (n & (~3)) * BlkCount + (k & (~(BlkPerSubBlk - 1))) * 4 + (n & 3) * BlkPerSubBlk + (k & (BlkPerSubBlk - 1)); + ASSERT_EQ(packedScale[idx], refScale[idx]) + << " n " << n << " k " << k; + } + for (; k < BlkCount; ++k) { + size_t idx = (n & (~3)) * BlkCount + k * 4 + (n & 3); + ASSERT_EQ(packedScale[idx], refScale[idx]) + << " n " << n << " k " << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + ASSERT_EQ(packedScale[n * BlkCount + k], refScale[n * BlkCount + k]) + << " n " << n << " k " << k; + } + } + } + + template + void CheckBlkSum(const float* packedBlkSum, const float* refBlkSum) { + size_t BlkCount = (K + BlkLen - 1) / BlkLen; + + for (size_t n = 0; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = (((n) / 16) * BlkCount + k) * 16 + (n) % 16; + ASSERT_EQ(packedBlkSum[idx], refBlkSum[idx]) + << " n " << n << " k " << k; + } + } + } + + template + void TestPrepack() { + if (!MlasIsQNBitGemmAvailable(8, BlkLen, SQNBIT_CompInt8)) return; + + constexpr size_t Bits = 8; + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t Ldb = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t PackBCount = N * Ldb; + constexpr size_t ScaleCount = BlkCount * N; + const size_t BufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, Bits, BlkLen, hasZp, SQNBIT_CompInt8); + + const auto* inputB = inputB_.GetFilledBuffer(PackBCount, [this](uint8_t* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = static_cast(this->distrib_u8_(this->gen_)); + } + }); + + const auto* inputScale = inputScale_.GetFilledBuffer(ScaleCount, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + const auto* inputZp = hasZp ? inputZp_.GetFilledBuffer(ScaleCount, [this](uint8_t* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = static_cast(this->distrib_u8_(this->gen_)); + } + }) + : nullptr; + + auto* packedBuffer = packedBuffer_.GetBuffer(BufferSize, true); + auto* refB = refB_.GetBuffer(PackBCount, true); + auto* refScale = refScale_.GetBuffer(ScaleCount, true); + auto* refBlkSum = refBlkSum_.GetBuffer(((N + 15) & (~15)) * BlkCount, true); + + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, + inputScale, hasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, + inputScale, hasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, + nullptr, hasZp, inputZp, nullptr); + + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen); + + PrepackB(inputB, refB); + PrepackBlkSumAndScale(inputScale, inputZp, refScale, refBlkSum); + + CheckB(refB, reinterpret_cast(packedQuantB.PackedQuantBData)); + CheckScale(refScale, packedQuantB.PackedQuantBScale); + CheckBlkSum(refBlkSum, packedQuantB.QuantBBlkSum); + } + + public: + MlasSQ8BitPrepackTest() + : seed_(19287), gen_(seed_), distrib_u8_(0, 255), distrib_f32_(-10.f, 10.f) { + } + + static const char* GetTestSuiteName() { + return "SQ8BitPrepack"; + } + + template + void Execute(void) { + TestPrepack(); + TestPrepack(); + } + + void ExecuteShort(void) override { + auto& platform = GetMlasPlatform(); + + if (platform.Avx512Supported_) { + Execute<1, 1, 16, 128>(); + Execute<1, 1, 32, 128>(); + Execute<1, 1, 64, 128>(); + Execute<1, 1, 128, 128>(); + Execute<1, 1, 256, 128>(); + + Execute<16, 4, 16, 128>(); + Execute<32, 4, 16, 128>(); + Execute<64, 4, 16, 128>(); + Execute<128, 4, 16, 128>(); + + Execute<15, 5, 16, 128>(); + Execute<15, 5, 32, 128>(); + Execute<15, 5, 64, 128>(); + Execute<15, 5, 128, 128>(); + Execute<15, 5, 256, 128>(); + + Execute<17, 8, 16, 128>(); + Execute<17, 8, 32, 128>(); + Execute<17, 8, 64, 128>(); + Execute<17, 8, 128, 128>(); + Execute<17, 8, 256, 128>(); + + Execute<256, 16, 16, 128>(); + Execute<257, 17, 32, 128>(); + Execute<255, 15, 64, 128>(); + Execute<256, 17, 128, 128>(); + Execute<257, 16, 256, 128>(); + } else { + Execute<1, 1, 16, 64>(); + Execute<1, 1, 32, 64>(); + Execute<1, 1, 64, 64>(); + Execute<1, 1, 128, 64>(); + Execute<1, 1, 256, 64>(); + + Execute<16, 4, 16, 64>(); + Execute<32, 4, 16, 64>(); + Execute<64, 4, 16, 64>(); + Execute<128, 4, 16, 64>(); + + Execute<15, 5, 16, 64>(); + Execute<15, 5, 32, 64>(); + Execute<15, 5, 64, 64>(); + Execute<15, 5, 128, 64>(); + Execute<15, 5, 256, 64>(); + + Execute<17, 8, 16, 64>(); + Execute<17, 8, 32, 64>(); + Execute<17, 8, 64, 64>(); + Execute<17, 8, 128, 64>(); + Execute<17, 8, 256, 64>(); + + Execute<159, 16, 16, 64>(); + Execute<160, 17, 32, 64>(); + Execute<161, 15, 64, 64>(); + Execute<160, 17, 128, 64>(); + Execute<159, 16, 256, 64>(); + } + } +}; + +class MlasSQ8BitGemmKernelTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_f32_; + MatrixGuardBuffer packedBuffer_, workspace_, packedB_, Zp_; + MatrixGuardBuffer A_, B_, C_, ref_, bias_, scale_; + + bool FloatEqual(float v0, float v1, float rtol, float atol) { + return std::abs(v0 - v1) <= std::abs(v1 * rtol) + atol; + } + + template + void MatMul(const float* A, size_t lda, const float* B, const float* bias, float* C, size_t ldc) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = bias ? bias[n] : 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[m * lda + k]; + float b = B[n * K + k]; + accu += a * b; + } + C[m * ldc + n] = accu; + } + } + } + + template + void Check(const float* target, const float* ref, size_t ldc, float rtol, float atol) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + size_t i = m * ldc + n; + ASSERT_TRUE(FloatEqual(target[i], ref[i], rtol, atol)) + << " M " << M << " K " << K << " N " << N << " BlkLen " << BlkLen + << " v0 " << target[i] << " v1 " << ref[i] + << " m " << m << " n " << n; + } + } + } + + template + void TestSQ8BitGemmKernel() { + if (!MlasIsQNBitGemmAvailable(8, BlkLen, SQNBIT_CompInt8)) return; + + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t ldb = BlkCount * BlkLen; + constexpr size_t lda = ldb; + constexpr size_t ldc = (N + 15) & (~15); + const auto* A = A_.GetFilledBuffer(M * lda, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + auto* B = B_.GetFilledBuffer(K * N, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + int q_rows, q_cols; + MlasBlockwiseQuantizedShape((int)BlkLen, true, (int)K, (int)N, q_rows, q_cols); + + size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes<8>((int)(BlkLen), true, (int)K, (int)N, + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + auto* inputB = packedB_.GetBuffer(q_data_size_in_bytes, true); + auto* inputScale = scale_.GetBuffer(q_scale_size, true); + auto* inputZp = HasZp ? Zp_.GetBuffer(q_zp_size_in_bytes, true) : nullptr; + + MlasQuantizeBlockwise( + inputB, + inputScale, + inputZp, + B, + BlkLen, + true, + K, + N, + N, + nullptr); + + MlasDequantizeBlockwise( + B, + inputB, + inputScale, + inputZp, + BlkLen, + true, + K, + N, + nullptr); + + size_t bufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, 8, BlkLen, HasZp, SQNBIT_CompInt8); + auto* packedBuffer = packedBuffer_.GetBuffer(bufferSize, true); + + MlasQNBitGemmPackQuantBData( + N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, + inputScale, HasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( + N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, + inputScale, HasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( + N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, + nullptr, HasZp, inputZp, nullptr); + + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen); + + auto* C = C_.GetBuffer(M * ldc, true); + auto* ref = ref_.GetBuffer(M * ldc, true); + + auto* bias = HasBias ? bias_.GetFilledBuffer(N, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }) + : nullptr; + + const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, 8, BlkLen, HasZp, SQNBIT_CompInt8); + auto* workspace = workspace_.GetBuffer(workspace_size, true); + + MLAS_QNBIT_GEMM_DATA_PARAMS data; + data.A = A; + data.lda = lda; + data.QuantBDataWorkspace = packedBuffer; + data.PackedQuantBData = packedQuantB.PackedQuantBData; + data.QuantBScale = inputScale; + data.QuantBZeroPoint = inputZp; + data.Bias = bias; + data.C = C; + data.ldc = ldc; + + MlasQNBitGemmBatch(M, N, K, 1, 8, BlkLen, SQNBIT_CompInt8, &data, workspace, nullptr); + + MatMul(A, lda, B, bias, ref, ldc); + Check(C, ref, ldc, 0.01f, 0.02f); + } + + public: + MlasSQ8BitGemmKernelTest() + : seed_(1234), gen_(seed_), distrib_f32_(-0.25f, 0.25f) { + } + + static const char* GetTestSuiteName() { + return "SQ8BitGemmKernel"; + } + + template + void Execute(void) { + TestSQ8BitGemmKernel(); + TestSQ8BitGemmKernel(); + TestSQ8BitGemmKernel(); + TestSQ8BitGemmKernel(); + } + + void ExecuteShort(void) override { + Execute<1, 1, 1, 16>(); + Execute<7, 128, 4, 16>(); + Execute<8, 497, 5, 16>(); + Execute<1, 3072, 128, 16>(); + Execute<2, 3072, 128, 16>(); + + Execute<1, 1, 1, 32>(); + Execute<8, 33, 5, 32>(); + Execute<8, 513, 9, 32>(); + Execute<1, 3072, 128, 32>(); + Execute<2, 3072, 128, 32>(); + + Execute<1, 1, 1, 64>(); + Execute<8, 497, 9, 64>(); + Execute<1, 3072, 128, 64>(); + Execute<2, 3072, 128, 64>(); + + Execute<1, 1, 1, 128>(); + Execute<6, 255, 7, 128>(); + Execute<5, 257, 9, 128>(); + Execute<1, 3072, 128, 128>(); + Execute<2, 3072, 128, 128>(); + + Execute<1, 1, 1, 256>(); + Execute<7, 255, 7, 256>(); + Execute<6, 257, 7, 256>(); + Execute<1, 3072, 128, 256>(); + Execute<2, 3072, 128, 256>(); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index bc8b672512d8d..f56f9ffcc7858 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1448,6 +1448,8 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"gridsample_reflection_padding", "result differs"}); broken_tests->insert({"gridsample_volumetric_nearest_align_corners_0", "unknown version"}); broken_tests->insert({"gridsample_volumetric_nearest_align_corners_1", "unknown version"}); + broken_tests->insert({"rotary_embedding_no_position_ids_expanded", "unknown version"}); + broken_tests->insert({"rotary_embedding_no_position_ids_interleaved_expanded", "unknown version"}); broken_tests->insert({"spacetodepth", "result differs"}); broken_tests->insert({"reduce_sum_square_empty_set_expanded", "unknown version"}); // Fails with QNN SDK 2.17.0: diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc old mode 100755 new mode 100644 index cd10429e5eae0..8b3f55c7df756 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -6283,6 +6283,7 @@ TEST_F(GraphTransformationTests, MatMulScaleFusionUnfusableModels) { MODEL_FOLDER "fusion/matmul_scale_unfusable_div_not_scale.onnx", MODEL_FOLDER "fusion/matmul_scale_unfusable_scale_not_scalar.onnx", MODEL_FOLDER "fusion/matmul_scale_unfusable_scale_not_constant.onnx", + MODEL_FOLDER "fusion/matmul_scale_unfusable_scale_broadcasting_changes_shape.onnx", }; for (const auto& path : unfusable_model_paths) { diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc old mode 100755 new mode 100644 diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 77800505df9b7..f7760c49d4e79 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -5349,6 +5349,62 @@ TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Weight_Bias) { #endif } +// Tests that the WeightBiasQuantization optimizer does not process nodes that do not +// already have an output that is consumed by a single QuantizeLinear node. +TEST(QDQTransformerTests, WeightBiasQuantization_SkipIfOutputNotQuantized) { + auto test_case = [](bool add_final_reshape) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_arg = builder.MakeInput({1, 24, 67, 67}, std::numeric_limits::min(), + std::numeric_limits::max()); + NodeArg* weight_arg = builder.MakeInitializer({24, 1, 5, 5}, -0.1f, 0.1f); + NodeArg* bias_arg = builder.MakeInitializer({24}, -0.1f, 0.1f); + NodeArg* input_dq_arg = builder.MakeIntermediate(); + NodeArg* conv_output_arg = add_final_reshape ? builder.MakeIntermediate() : builder.MakeOutput(); + + builder.AddDequantizeLinearNode(input_arg, 0.014f, static_cast(127), input_dq_arg); + auto& conv_node = builder.AddNode("Conv", {input_dq_arg, weight_arg, bias_arg}, {conv_output_arg}); + conv_node.AddAttribute("dilations", std::vector{1, 1}); + conv_node.AddAttribute("kernel_shape", std::vector{5, 5}); + conv_node.AddAttribute("strides", std::vector{2, 2}); + conv_node.AddAttribute("group", static_cast(24)); + conv_node.AddAttribute("pads", std::vector{0, 0, 0, 0}); + + // Make adding a final Reshape node configurable to test two cases: + // - Conv produces a graph output + // - Conv output is consumed by some node that is NOT a QuantizeLinear + // In either case, the WeightBiasQuantization optimizer should skip this node. + if (add_final_reshape) { + NodeArg* reshape_output_arg = builder.MakeOutput(); + NodeArg* new_shape_arg = builder.Make1DInitializer({1, -1}); + builder.AddNode("Reshape", {conv_output_arg, new_shape_arg}, {reshape_output_arg}); + } + }; + + auto check_graph = [add_final_reshape](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + + // Should retain the same nodes in the original graph. + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + EXPECT_EQ(op_to_count["Conv"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count["Reshape"], static_cast(add_final_reshape)); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + 21, + /*per_sample_tolerance*/ 0.0, + /*relative_per_sample_tolerance*/ 0.0, + std::make_unique()); + }; + + test_case(false); // Conv produces a graph output directly + test_case(true); // Conv -> Reshape -> graph_output +} + TEST(QDQTransformerTests, WeightBiasQuantization_ConvTranspose_Weight) { auto test_case = [](bool use_contrib_qdq) { auto build_test_case = [&](ModelTestBuilder& builder) { diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 5e9b50c537465..7263c435a6a2e 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -537,7 +537,7 @@ void BaseTester::Run(ExpectResult expect_result, const std::string& expected_fai SessionOptions so; so.use_per_session_threads = false; so.session_logid = test_name_; - so.session_log_verbosity_level = 0; + so.session_log_verbosity_level = 1; so.execution_mode = execution_mode; so.use_deterministic_compute = use_determinism_; so.graph_optimization_level = TransformerLevel::Default; // 'Default' == off diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index 73f31787e0597..65bd549294658 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -501,6 +501,18 @@ TEST(Einsum, ExplicitEinsumReduceAxesInInputToScalarOutput) { test.Run(); } +TEST(Einsum, ExplicitEinsumReduceAxesInInputToScalarOutput_Multi_Input) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + // Matrix multiplication first and then reduction to scalar. + // Step1: ij,jk->ik + // Step2: ik-> + test.AddAttribute("equation", "ij,jk->"); + test.AddInput("x", {2, 2}, {1.f, 2.f, 3.f, 4.f}); + test.AddInput("y", {2, 2}, {1.f, 2.f, 3.f, 4.f}); + test.AddOutput("o", {}, {54}); + test.Run(); +} + // Implicit TEST(Einsum, ImplicitEinsumAsElementwiseMulOpWithOneScalar) { OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index d0069a0069646..400b5ab20930c 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -433,6 +433,48 @@ TYPED_TEST(GemmOpTypedTests, TestGemm2DBroadcast_2) { .RunWithConfig(); } +TYPED_TEST(GemmOpTypedTests, TestGemm2DBroadcast_3) { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + // Same as GemmBroadcast, but adding the unnecessary first dimension. + test.AddInput("A", {3, 4}, + std::vector(12, static_cast(1.0f))); + test.AddInput("B", {4, 4}, std::vector(16, static_cast(1.0f))); + test.AddInput("C", {1, 1}, std::vector{static_cast(1.0f)}); + test.AddOutput("Y", {3, 4}, + std::vector{static_cast(5.0f), static_cast(5.0f), static_cast(5.0f), static_cast(5.0f), + static_cast(5.0f), static_cast(5.0f), static_cast(5.0f), static_cast(5.0f), + static_cast(5.0f), static_cast(5.0f), static_cast(5.0f), static_cast(5.0f)}); + test.Config(run_with_tunable_op) + .RunWithConfig(); +} + +TYPED_TEST(GemmOpTypedTests, TestGemm2DBroadcast_4) { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + // Same as GemmBroadcast, but adding the unnecessary first dimension. + test.AddInput("A", {3, 4}, + std::vector(12, static_cast(1.0f))); + test.AddInput("B", {4, 4}, std::vector(16, static_cast(1.0f))); + test.AddInput("C", {3, 1}, std::vector{static_cast(1.0f), static_cast(2.0f), static_cast(3.0f)}); + test.AddOutput("Y", {3, 4}, + std::vector{static_cast(5.0f), static_cast(5.0f), static_cast(5.0f), static_cast(5.0f), + static_cast(6.0f), static_cast(6.0f), static_cast(6.0f), static_cast(6.0f), + static_cast(7.0f), static_cast(7.0f), static_cast(7.0f), static_cast(7.0f)}); + test.Config(run_with_tunable_op) + .RunWithConfig(); +} + TYPED_TEST(GemmOpTypedTests, TestGemmFalseBroadcast) { OpTester test("Gemm"); @@ -914,5 +956,237 @@ TEST(GemmOpTest, SharedPrepackedWeights) { } #endif +TEST(GemmOpTest, GemmOptimizeVec4) { + auto run_test = [](int64_t M, int64_t K, int64_t N) { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + // Matrix A: MxK filled with sequential numbers + std::vector a_data; + a_data.reserve(M * K); + for (int64_t i = 0; i < M * K; ++i) { + a_data.push_back(static_cast((i % 7) + 1)); + } + + // Matrix B: KxN filled with sequential numbers + std::vector b_data; + b_data.reserve(K * N); + for (int64_t i = 0; i < K * N; ++i) { + b_data.push_back(static_cast((i % 7) + 1)); + } + + // Matrix C: MxN filled with zeros + std::vector c_data(M * N, 1.0f); + + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {K, N}, b_data); + test.AddInput("C", {M, N}, c_data); + + // Calculate expected output + std::vector expected_data(M * N, 0.0f); + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + sum += a_data[i * K + k] * b_data[k * N + j]; + } + expected_data[i * N + j] = sum + c_data[i * N + j]; + } + } + + test.AddOutput("Y", {M, N}, expected_data); + test.Config(run_with_tunable_op) + .RunWithConfig(); + }; + + run_test(60, 16, 92); + + run_test(8, 8, 8); + run_test(128, 128, 128); + run_test(128, 32, 64); + run_test(4, 8, 12); + + run_test(96, 24, 48); + run_test(48, 48, 120); + run_test(72, 80, 84); +} + +TEST(GemmOpTest, GemmOptimizeVec4TransA) { + auto run_test = [](int64_t M, int64_t K, int64_t N) { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)1); // A is transposed + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + // Matrix A: KxM (will be transposed to MxK) filled with sequential numbers + std::vector a_data; + a_data.reserve(M * K); + for (int64_t i = 0; i < K * M; ++i) { + a_data.push_back(static_cast((i % 7) + 1)); + } + + // Matrix B: KxN filled with sequential numbers + std::vector b_data; + b_data.reserve(K * N); + for (int64_t i = 0; i < K * N; ++i) { + b_data.push_back(static_cast((i % 7) + 1)); + } + + // Matrix C: MxN filled with zeros + std::vector c_data(M * N, 1.0f); + + test.AddInput("A", {K, M}, a_data); // Note dimensions are swapped + test.AddInput("B", {K, N}, b_data); + test.AddInput("C", {M, N}, c_data); + + // Calculate expected output for transposed A + std::vector expected_data(M * N, 0.0f); + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + sum += a_data[k * M + i] * b_data[k * N + j]; // Adjusted index for transposed A + } + expected_data[i * N + j] = sum + c_data[i * N + j]; + } + } + + test.AddOutput("Y", {M, N}, expected_data); + test.Config(run_with_tunable_op) + .RunWithConfig(); + }; + + run_test(60, 16, 92); + run_test(8, 8, 8); + run_test(128, 128, 128); + run_test(128, 32, 64); + run_test(4, 8, 12); + run_test(96, 24, 48); + run_test(48, 48, 120); + run_test(72, 80, 84); +} + +TEST(GemmOpTest, GemmOptimizeVec4TransB) { + auto run_test = [](int64_t M, int64_t K, int64_t N) { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)1); // B is transposed + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + // Matrix A: MxK filled with sequential numbers + std::vector a_data; + a_data.reserve(M * K); + for (int64_t i = 0; i < M * K; ++i) { + a_data.push_back(static_cast((i % 7) + 1)); + } + + // Matrix B: NxK (will be transposed to KxN) filled with sequential numbers + std::vector b_data; + b_data.reserve(K * N); + for (int64_t i = 0; i < N * K; ++i) { + b_data.push_back(static_cast((i % 7) + 1)); + } + + // Matrix C: MxN filled with zeros + std::vector c_data(M * N, 1.0f); + + test.AddInput("A", {M, K}, a_data); + test.AddInput("B", {N, K}, b_data); // Note dimensions are swapped + test.AddInput("C", {M, N}, c_data); + + // Calculate expected output for transposed B + std::vector expected_data(M * N, 0.0f); + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + sum += a_data[i * K + k] * b_data[j * K + k]; // Adjusted index for transposed B + } + expected_data[i * N + j] = sum + c_data[i * N + j]; + } + } + + test.AddOutput("Y", {M, N}, expected_data); + test.Config(run_with_tunable_op) + .RunWithConfig(); + }; + + run_test(32, 32, 32); + run_test(60, 16, 92); + run_test(8, 8, 8); + run_test(64, 64, 64); + run_test(128, 128, 128); + run_test(128, 32, 64); + run_test(4, 8, 12); + run_test(96, 24, 48); + run_test(48, 48, 120); + run_test(72, 80, 84); +} + +TEST(GemmOpTest, GemmOptimizeVec4TransAB) { + auto run_test = [](int64_t M, int64_t K, int64_t N) { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)1); // A is transposed + test.AddAttribute("transB", (int64_t)1); // B is transposed + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + // Matrix A: KxM (will be transposed to MxK) filled with sequential numbers + std::vector a_data; + a_data.reserve(M * K); + for (int64_t i = 0; i < K * M; ++i) { + a_data.push_back(static_cast((i % 7) + 1)); + } + + // Matrix B: NxK (will be transposed to KxN) filled with sequential numbers + std::vector b_data; + b_data.reserve(K * N); + for (int64_t i = 0; i < N * K; ++i) { + b_data.push_back(static_cast((i % 7) + 1)); + } + + // Matrix C: MxN filled with zeros + std::vector c_data(M * N, 1.0f); + + test.AddInput("A", {K, M}, a_data); // Note dimensions are swapped + test.AddInput("B", {N, K}, b_data); // Note dimensions are swapped + test.AddInput("C", {M, N}, c_data); + + // Calculate expected output for both matrices transposed + std::vector expected_data(M * N, 0.0f); + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + sum += a_data[k * M + i] * b_data[j * K + k]; // Adjusted indices for both transposed + } + expected_data[i * N + j] = sum + c_data[i * N + j]; + } + } + + test.AddOutput("Y", {M, N}, expected_data); + test.Config(run_with_tunable_op) + .RunWithConfig(); + }; + run_test(32, 32, 32); + run_test(60, 16, 92); + run_test(8, 8, 8); + run_test(128, 128, 128); + run_test(128, 32, 64); + run_test(4, 8, 12); + run_test(96, 24, 48); + run_test(48, 48, 120); + run_test(72, 80, 84); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index 9515c8eb78ed6..f3a963ce47eda 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -9,6 +9,8 @@ #include "test/common/trt_op_test_utils.h" #include +#include +#include #include #include #include @@ -22,9 +24,19 @@ namespace onnxruntime { namespace test { -std::string WideToUTF8(const std::wstring& wstr) { +std::string PathToUTF8(const PathString& path) { +#ifdef WIN32 std::wstring_convert> converter; - return converter.to_bytes(wstr); + return converter.to_bytes(path); +#else + return path.c_str(); +#endif +} + +void clearFileIfExists(PathString path) { + if (std::filesystem::exists(path)) { + std::filesystem::remove(path); + } } template @@ -74,10 +86,10 @@ void VerifyOutputs(const std::vector& fetches, const std::vector dims, - bool add_fast_gelu = false) { +static void CreateBaseModel(const PathString& model_name, + std::string graph_name, + std::vector dims, + bool add_fast_gelu = false) { onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); auto& graph = model.MainGraph(); std::vector inputs; @@ -143,7 +155,7 @@ void CreateBaseModel(const PathString& model_name, status = onnxruntime::Model::Save(model, model_name); } -Ort::IoBinding generate_io_binding(Ort::Session& session, std::map> shape_overwrites = {}) { +static Ort::IoBinding generate_io_binding(Ort::Session& session, std::map> shape_overwrites = {}) { Ort::IoBinding binding(session); auto allocator = Ort::AllocatorWithDefaultOptions(); for (int input_idx = 0; input_idx < int(session.GetInputCount()); ++input_idx) { @@ -178,6 +190,8 @@ Ort::IoBinding generate_io_binding(Ort::Session& session, std::map dims = {1, 3, 2}; @@ -192,9 +206,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { auto start = std::chrono::high_resolution_clock::now(); Ort::SessionOptions so; Ort::RunOptions run_options; - so.AddConfigEntry("ep.context_enable", "1"); - so.AddConfigEntry("ep.context_file_path", WideToUTF8(model_name_ctx).c_str()); - so.AppendExecutionProvider("NvTensorRtRtx", {}); + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); Ort::Session session_object(env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -208,9 +222,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { auto start = std::chrono::high_resolution_clock::now(); Ort::SessionOptions so; Ort::RunOptions run_options; - so.AddConfigEntry("ep.context_enable", "1"); - so.AppendExecutionProvider("NvTensorRtRtx", {}); - Ort::Session session_object(env, model_name.c_str(), so); + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + Ort::Session session_object(env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -222,6 +236,8 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { PathString model_name = ORT_TSTR("nv_execution_provider_dyn_test.onnx"); PathString model_name_ctx = ORT_TSTR("nv_execution_provider_dyn_test_ctx.onnx"); + auto model_name_ctx_str = PathToUTF8(model_name_ctx); + clearFileIfExists(model_name_ctx); std::string graph_name = "test"; std::vector dims = {1, -1, -1}; @@ -236,9 +252,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { auto start = std::chrono::high_resolution_clock::now(); Ort::SessionOptions so; Ort::RunOptions run_options; - so.AddConfigEntry("ep.context_enable", "1"); - so.AddConfigEntry("ep.context_file_path", WideToUTF8(model_name_ctx).c_str()); - so.AppendExecutionProvider("NvTensorRtRtx", {}); + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); Ort::Session session_object(env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -252,9 +268,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { auto start = std::chrono::high_resolution_clock::now(); Ort::SessionOptions so; Ort::RunOptions run_options; - so.AddConfigEntry("ep.context_enable", "1"); - so.AppendExecutionProvider("NvTensorRtRtx", {}); - Ort::Session session_object(env, model_name.c_str(), so); + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + Ort::Session session_object(env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -269,6 +285,8 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { PathString model_name = ORT_TSTR("nv_execution_provider_data_dyn_test.onnx"); PathString model_name_ctx = ORT_TSTR("nv_execution_provider_data_dyn_test_ctx.onnx"); + auto model_name_ctx_str = PathToUTF8(model_name_ctx); + clearFileIfExists(model_name_ctx); std::string graph_name = "test"; std::vector dims = {1, -1, -1}; @@ -283,9 +301,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { auto start = std::chrono::high_resolution_clock::now(); Ort::SessionOptions so; Ort::RunOptions run_options; - so.AddConfigEntry("ep.context_enable", "1"); - so.AddConfigEntry("ep.context_file_path", WideToUTF8(model_name_ctx).c_str()); - so.AppendExecutionProvider("NvTensorRtRtx", {}); + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); Ort::Session session_object(env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -299,9 +317,9 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { auto start = std::chrono::high_resolution_clock::now(); Ort::SessionOptions so; Ort::RunOptions run_options; - so.AddConfigEntry("ep.context_enable", "1"); - so.AppendExecutionProvider("NvTensorRtRtx", {}); - Ort::Session session_object(env, model_name.c_str(), so); + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); + Ort::Session session_object(env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc new file mode 100644 index 0000000000000..55412a7b15d98 --- /dev/null +++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc @@ -0,0 +1,341 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" +#include "test/util/include/test_utils.h" + +#include "core/graph/onnx_protobuf.h" +#include "gtest/gtest.h" + +namespace { + +using onnxruntime::Node; +using onnxruntime::NodeArg; +using onnxruntime::ProviderOptions; +using onnxruntime::test::AddQDQNodePair; +using onnxruntime::test::AddQDQNodePairWithOutputAsGraphOutput; +using onnxruntime::test::BuildOpTestCase; +using onnxruntime::test::ExpectedEPNodeAssignment; +using onnxruntime::test::GetTestInputQuantParams; +using onnxruntime::test::GetTestQDQModelFn; +using onnxruntime::test::MakeTestInput; +using onnxruntime::test::ModelTestBuilder; +using onnxruntime::test::QDQTolerance; +using onnxruntime::test::QuantParams; +using onnxruntime::test::RunQnnModelTest; +using onnxruntime::test::TestInputDef; +using onnxruntime::test::TestQDQModelAccuracy; +using onnxruntime::utils::MakeAttribute; + +constexpr char kEinsumOp[] = "Einsum"; +constexpr char kEinsumEquation[] = "equation"; +constexpr char kQnnBackendType[] = "backend_type"; +constexpr char kQnnBackendTypeCpu[] = "cpu"; +constexpr char kQnnBackendTypeHtp[] = "htp"; +constexpr char kOffloadGraphIoQuantization[] = "offload_graph_io_quantization"; +constexpr char kOffloadGraphIoQuantizationDisable[] = "0"; + +template +static void RunQnnEinsum( + const std::string& backend, + const TestInputDef& in0, + const TestInputDef& in1, + const std::string& equation, + const float tolerance) { + ProviderOptions provider_options; + provider_options[kQnnBackendType] = backend; + provider_options[kOffloadGraphIoQuantization] = kOffloadGraphIoQuantizationDisable; + RunQnnModelTest( + /*build_test_case=*/BuildOpTestCase( + /*op_type=*/kEinsumOp, + /*input_defs_1=*/{in0, in1}, + /*input_defs_2=*/{}, + /*attrs=*/{MakeAttribute(kEinsumEquation, equation)}), + /*provider_options=*/provider_options, + /*opset_version=*/12, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*tolerance=*/tolerance); +} + +template +GetTestQDQModelFn BuildTestCaseQdq(const std::vector>& input_defs, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_defs, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + const size_t num_inputs = input_defs.size(); + + std::vector op_inputs; + op_inputs.reserve(num_inputs); + + // Process input 0 + NodeArg* input0 = MakeTestInput(builder, input_defs[0]); + QuantParams input0_qparams = GetTestInputQuantParams(input_defs[0]); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale, + input0_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input0_after_qdq); + + // Process input 1 + NodeArg* input1 = MakeTestInput(builder, input_defs[1]); + QuantParams input1_qparams = GetTestInputQuantParams(input_defs[1]); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input1_after_qdq); + + // Op -> op_output + auto* output = builder.MakeIntermediate(); + Node& node = builder.AddNode(kEinsumOp, op_inputs, {output}); + for (const auto& attr : attrs) { + node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +template +static void RunQnnHtpQdqEinsum(const TestInputDef& in0, + const TestInputDef& in1, + const std::string& equation, + QDQTolerance tolerance) { + ProviderOptions provider_options; + provider_options[kQnnBackendType] = kQnnBackendTypeHtp; + provider_options[kOffloadGraphIoQuantization] = kOffloadGraphIoQuantizationDisable; + std::vector attrs{MakeAttribute(kEinsumEquation, equation)}; + auto f32_model_builder = BuildOpTestCase( + /*op_type=*/kEinsumOp, + /*input_defs_1=*/{in0, in1}, + /*input_defs_2=*/{}, + /*attrs=*/attrs); + auto qdq_model_builder = BuildTestCaseQdq( + /*input_defs=*/{in0, in1}, /*attrs=*/attrs, /*use_contrib_qdq=*/false); + TestQDQModelAccuracy(/*f32_model_fn=*/f32_model_builder, + /*qdq_model_fn=*/qdq_model_builder, + /*qnn_options=*/provider_options, + /*opset_version=*/12, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*tolerance=*/tolerance); +} + +} // namespace + +namespace onnxruntime { +namespace test { + +// +// QNN CPU +// + +TEST_F(QnnCPUBackendTests, EinsumRank2) { + const std::vector shape0{2, 3}; + const std::vector shape1{3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"ab,bc->ac", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnCPUBackendTests, EinsumRank4MatMul) { + const std::vector shape0{3, 4, 5, 6}; + const std::vector shape1{3, 4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhij,bhjd->bhid", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeY) { + const std::vector shape0{2, 3, 4, 6}; + const std::vector shape1{2, 3, 5, 6}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhid,bhjd->bhij", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) { + const std::vector shape0{1, 9, 1, 7}; + const std::vector shape1{1, 7, 1, 9}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bchq,bkhc->bkhq", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) { + const std::vector shape0{1, 7, 1, 7}; + const std::vector shape1{1, 9, 1, 7}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bkhq,bchk->bchq", + /*tolerance=*/1e-4f); +} + +// +// QNN HTP F16 +// + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +TEST_F(QnnHTPBackendTests, EinsumF16Rank2MatMul) { + const std::vector shape0{2, 3}; + const std::vector shape1{3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"ij,jk->ik", + /*tolerance=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMul) { + const std::vector shape0{3, 1, 5, 2}; + const std::vector shape1{3, 1, 2, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhij,bhjd->bhid", + /*tolerance=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeY) { + const std::vector shape0{2, 3, 4, 2}; + const std::vector shape1{2, 3, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhid,bhjd->bhij", + /*tolerance=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll1) { + const std::vector shape0{1, 3, 1, 7}; + const std::vector shape1{1, 7, 1, 3}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bchq,bkhc->bkhq", + /*tolerance=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll2) { + const std::vector shape0{1, 4, 1, 4}; + const std::vector shape1{1, 9, 1, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bkhq,bchk->bchq", + /*tolerance=*/1e-2f); +} + +// +// QNN HTP QDQ +// + +TEST_F(QnnHTPBackendTests, EinsumQdqRank2MatMul) { + const std::vector shape0{2, 3}; + const std::vector shape1{3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"ij,jk->ik", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMul) { + const std::vector shape0{3, 1, 5, 2}; + const std::vector shape1{3, 1, 2, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhij,bhjd->bhid", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeY) { + const std::vector shape0{2, 3, 4, 2}; + const std::vector shape1{2, 3, 5, 2}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhid,bhjd->bhij", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll1) { + const std::vector shape0{1, 3, 1, 7}; + const std::vector shape1{1, 7, 1, 3}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bchq,bkhc->bkhq", + /*tolerance=*/QDQTolerance()); +} + +TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll2) { + const std::vector shape0{1, 4, 1, 4}; + const std::vector shape1{1, 9, 1, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bkhq,bchk->bchq", + /*tolerance=*/QDQTolerance()); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index b75751f89a6c7..f736abcd3006d 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -27,6 +27,8 @@ using namespace onnxruntime::logging; // in test_main.cc extern std::unique_ptr ort_env; +extern "C" void ortenv_setup(); +extern "C" void ortenv_teardown(); namespace onnxruntime { namespace test { @@ -1232,6 +1234,37 @@ TEST_F(QnnHTPBackendTests, UseHtpSharedMemoryAllocatorForInputs) { } #endif // BUILD_QNN_EP_STATIC_LIB +#if !BUILD_QNN_EP_STATIC_LIB +// Tests that loading and unloading of an EP library in the same process does not cause a segfault. +TEST_F(QnnHTPBackendTests, LoadingAndUnloadingOfQnnLibrary_FixSegFault) { + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx"; + + onnxruntime::ProviderOptions options; + options["backend_type"] = "htp"; + options["offload_graph_io_quantization"] = "0"; + + // This first session will load the QNN EP library for the first time. + { + Ort::SessionOptions so; + so.AppendExecutionProvider("QNN", options); + + EXPECT_NO_THROW(Ort::Session session(*ort_env, ort_model_path, so)); + } + + { + ortenv_teardown(); // Destroy Env to force unloading of EP libraries. + ortenv_setup(); + + // This next session will reload the QNN EP library. + // Should not get a segfault. + Ort::SessionOptions so; + so.AppendExecutionProvider("QNN", options); + + EXPECT_NO_THROW(Ort::Session session(*ort_env, ort_model_path, so)); + } +} +#endif // !BUILD_QNN_EP_STATIC_LIB + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 0eec5f800916f..bfdb1a1a6afdd 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -991,6 +991,22 @@ TEST_F(QnnHTPBackendTests, BinaryOp_And4D) { ExpectedEPNodeAssignment::All); } +// Test ScatterND op on HTP +TEST_F(QnnHTPBackendTests, ScatterND_int64_int64) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterND", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + {}, + 17, + ExpectedEPNodeAssignment::All); +} + // Test that Or is not yet supported on CPU backend. TEST_F(QnnHTPBackendTests, BinaryOp_HTP_Or_Unsupported) { RunOpTest("Or", diff --git a/onnxruntime/test/providers/qnn/upsample_op_test.cc b/onnxruntime/test/providers/qnn/upsample_op_test.cc new file mode 100644 index 0000000000000..3371bbef44e1b --- /dev/null +++ b/onnxruntime/test/providers/qnn/upsample_op_test.cc @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "core/graph/onnx_protobuf.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Upsample operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunUpsampleTestOnCPU(const TestInputDef& input_def, + const TestInputDef& scales_def, + std::vector&& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 9) { + ProviderOptions provider_options; + provider_options["backend_type"] = "cpu"; + provider_options["offload_graph_io_quantization"] = "0"; + + if (opset <= 7) { + const std::vector& scales = scales_def.GetRawData(); + attrs.push_back(utils::MakeAttribute("scales", scales)); + + RunQnnModelTest(BuildOpTestCase("Upsample", {input_def}, {}, attrs), + provider_options, + opset, + expected_ep_assignment); + } else { + RunQnnModelTest(BuildOpTestCase("Upsample", {input_def}, {scales_def}, attrs), + provider_options, + opset, + expected_ep_assignment); + } +} + +// +// CPU tests: +// + +// Test that Upsample with a dynamic scales input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Upsample_DynamicScales_Unsupported) { + RunUpsampleTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({4}, false /* is_initializer */, {1.0f, 1.0f, 1.5f, 1.5f}), + {utils::MakeAttribute("mode", "nearest")}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 9); // Opset +} + +// Test Upsample with opset-9, mode `nearest` +TEST_F(QnnCPUBackendTests, Upsample_4D_Nearest_opset9) { + RunUpsampleTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({4}, true, {1.0f, 1.0f, 1.5f, 1.5f}), + {utils::MakeAttribute("mode", "nearest")}, // Attributes + ExpectedEPNodeAssignment::All, + 9); // Opset +} + +// Test Upsample with opset-9, mode `linear` +TEST_F(QnnCPUBackendTests, Upsample_4D_Linear_opset9) { + RunUpsampleTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({4}, true, {1.0f, 1.0f, 1.5f, 1.5f}), + {utils::MakeAttribute("mode", "linear")}, // Attributes + ExpectedEPNodeAssignment::All, + 9); // Opset +} + +// Test Upsample with opset-7, mode `nearest` +TEST_F(QnnCPUBackendTests, Upsample_4D_Nearest_opset7) { + RunUpsampleTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({4}, true, {1.0f, 1.0f, 1.5f, 1.5f}), + {utils::MakeAttribute("mode", "nearest")}, // Attributes + ExpectedEPNodeAssignment::All, + 7); // Opset +} + +// Test Upsample with opset-7, mode `linear` +TEST_F(QnnCPUBackendTests, Upsample_4D_Linear_opset7) { + RunUpsampleTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({4}, true, {1.0f, 1.0f, 1.5f, 1.5f}), + {utils::MakeAttribute("mode", "linear")}, // Attributes + ExpectedEPNodeAssignment::All, + 7); // Opset +} + +// Test Upsample 5D +TEST_F(QnnCPUBackendTests, Upsample_5D) { + RunUpsampleTestOnCPU(TestInputDef({1, 3, 4, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({5}, true, {1.0f, 1.0f, 1.5f, 1.5f, 1.5f}), + {utils::MakeAttribute("mode", "nearest")}, // Attributes + ExpectedEPNodeAssignment::All, + 9); // Opset +} + +/* +QNN HTP backend tests for the QDQ Upsample model is bypassed and can not be enabled. + +ONNX Upsample is deprecated in domain version 10. However, ONNX QuantizeLinear and DequantizeLinear are enabled in +domain version 10. Their conditions are mutually exclusive, so it is not possible for these ops to coexist in the +same domain version. +*/ + +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/python/quantization/bench_matmul_8bits.py b/onnxruntime/test/python/quantization/bench_matmul_8bits.py new file mode 100644 index 0000000000000..6422847c8d9c9 --- /dev/null +++ b/onnxruntime/test/python/quantization/bench_matmul_8bits.py @@ -0,0 +1,670 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +Benchmark performance of MatMulNBits for CUDA in ONNX Runtime. +""" + +import argparse +import csv +import math +import statistics +from datetime import datetime + +import torch +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession, SessionOptions, get_available_providers +from onnxruntime.transformers.io_binding_helper import CudaSession + + +class MatMulNBitsConfig: + """ + Configuration for the MatMulNBits benchmark. + """ + + def __init__( + self, + bits: int, + m: int, + n: int, + k: int, + block_size: int, + device: torch.device, + dtype: torch.dtype = torch.float16, + has_zero_point: bool = True, + has_bias: bool = False, + enable_cuda_graph: bool = False, + ): + """ + Initializes the MatMulNBitsConfig. + + Args: + bits (int): Number of bits for quantization (e.g., 4, 8). + m (int): The M dimension of the MatMul operation (batch size). + n (int): The N dimension of the MatMul operation (output features). + k (int): The K dimension of the MatMul operation (input features). + block_size (int): The block size used for quantization along the K dimension. + device (torch.device): The device to run the benchmark on (e.g., torch.device('cuda:0')). + dtype (torch.dtype, optional): The data type for floating-point inputs and outputs. Defaults to torch.float16. + has_zero_point (bool, optional): Whether the quantized weights have a zero point. Defaults to True. + has_bias (bool, optional): Whether the MatMul operation includes a bias term. Defaults to False. + enable_cuda_graph (bool, optional): Whether to enable CUDA graph capture. Defaults to False. + """ + self.operator = "MatMulNBits" + self.bits = bits + self.m = m + self.n = n + self.k = k + self.block_size = block_size + self.has_zero_point = has_zero_point + self.is_int_zero_point = True + self.has_bias = has_bias + # This script is specifically for CUDA benchmarking + self.use_cuda = True + self.dtype = dtype + self.device = device + self.enable_cuda_graph = enable_cuda_graph + + if self.k % self.block_size != 0: + raise ValueError(f"K ({self.k}) must be divisible by block_size ({self.block_size}).") + + if self.bits not in [4, 8]: + raise ValueError(f"Bits must be 4 or 8, but got {self.bits}.") + + def __repr__(self): + """ + Returns a string representation of the configuration. + """ + return ( + f"{self.operator}(bits={self.bits}, m={self.m}, n={self.n}, k={self.k}, block_size={self.block_size}, " + f"dtype={self.dtype}, has_zero_point={self.has_zero_point}, " + f"has_bias={self.has_bias}, enable_cuda_graph={self.enable_cuda_graph}) " + ) + + def shape_dict(self) -> dict[str, tuple]: + """ + Returns a dictionary mapping input/output names to their shapes. + + Based on the MatMulNBits operator definition, input 'b' (weights) is + quantized and structured as (N, K/block_size, block_size). + Scales and zero_points are (N, K/block_size). + """ + k_blocks = self.k // self.block_size + shapes: dict[str, tuple] = { + "output": (self.m, self.n), + "a": (self.m, self.k), + "b": (self.n, k_blocks, self.block_size), # Quantized weights + "scales": (self.n, k_blocks), + } + if self.has_zero_point: + shapes["zero_points"] = (self.n, k_blocks) + + if self.has_bias: + shapes["bias"] = (self.n,) + + return shapes + + def random_inputs(self, seed: int = 123) -> dict[str, torch.Tensor]: + """ + Generates random input tensors based on the configuration. + + Args: + seed (int, optional): Random seed for reproducibility. Use 0 for no seed. Defaults to 123. + + Returns: + dict[str, torch.Tensor]: A dictionary of input tensors. + """ + device = self.device + dtype = self.dtype + + shape_dict = self.shape_dict() + + if seed > 0: + torch.manual_seed(seed) + if device.type == "cuda": + torch.cuda.manual_seed_all(seed) + + feeds = { + # 'a' is the activation tensor (M, K) + "a": torch.empty(shape_dict["a"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + # 'b' is the quantized weight tensor (N, K/block_size, block_size) + # Values should be within the [0, 2^bits - 1] range, but MatMulNBits takes UINT8. + # The actual range used by the kernel depends on 'bits'. + # Generating [0, 255] and letting the kernel handle the bit interpretation. + "b": torch.randint(0, 256, shape_dict["b"], device=device, dtype=torch.uint8), + # 'scales' is the scale tensor (N, K/block_size) + "scales": torch.empty(shape_dict["scales"], device=device, dtype=dtype).normal_(mean=0, std=10.0), + } + + if self.has_zero_point: + # 'zero_points' is the zero point tensor (N, K/block_size) + # Assuming is_int_zero_point is True, dtype is uint8. + # Values should be within [0, 2^bits - 1]. Generating [0, 255]. + feeds["zero_points"] = torch.randint(0, 256, shape_dict["zero_points"], device=device, dtype=torch.uint8) + + if self.has_bias: + # 'bias' is the bias tensor (N,) + feeds["bias"] = torch.empty(shape_dict["bias"], device=device, dtype=dtype).normal_(mean=0, std=0.1) + + return feeds + + def get_input_output_names(self) -> tuple[list[str], list[str]]: + """ + Returns the list of input and output names for the ONNX model. + """ + inputs = ["a", "b", "scales"] + if self.has_zero_point: + inputs.append("zero_points") + if self.has_bias: + inputs.append("bias") + + outputs = ["output"] + + return inputs, outputs + + +def create_matmul_nbits_onnx_model(config: MatMulNBitsConfig) -> bytes: + """ + Creates an ONNX model with a single MatMulNBits node. + + Args: + config (MatMulNBitsConfig): The configuration for the MatMulNBits node. + + Returns: + bytes: The serialized ONNX model. + """ + input_names, output_names = config.get_input_output_names() + + float_type = TensorProto.FLOAT16 if config.dtype == torch.float16 else TensorProto.FLOAT + nodes = [ + helper.make_node( + "MatMulNBits", + input_names, + output_names, + "MatMulNBits_0", # Node name + bits=config.bits, + block_size=config.block_size, + K=config.k, + N=config.n, + domain="com.microsoft", + ), + ] + + shape_dict = config.shape_dict() + # Input types based on ONNX MatMulNBits definition. 'a', 'scales', 'bias' are float types. + # 'b' and 'zero_points' are UINT8. + inputs = [ + helper.make_tensor_value_info( + input_name, + TensorProto.UINT8 if input_name in ["b", "zero_points"] else float_type, + list(shape_dict[input_name]), + ) + for input_name in input_names + if input_name + ] + + outputs = [ + helper.make_tensor_value_info(output_name, float_type, list(shape_dict[output_name])) + for output_name in output_names + if output_name + ] + + graph = helper.make_graph( + nodes, + "MatMulNBits_Graph", + inputs, + outputs, + ) + + model = helper.make_model(graph, producer_name="onnxruntime.benchmarks") + + return model.SerializeToString() + + +def create_ort_session( + config: MatMulNBitsConfig, + session_options: SessionOptions = None, + use_tf32: bool = False, +) -> InferenceSession: + """ + Creates an ONNX Runtime InferenceSession for the MatMulNBits model. + + Args: + config (MatMulNBitsConfig): The configuration for the session. + session_options (SessionOptions, optional): ONNX Runtime session options. Defaults to None. + use_tf32 (bool, optional): Whether to enable TF32 mode on CUDA. Defaults to False. + + Returns: + InferenceSession: The created ONNX Runtime InferenceSession. + """ + onnx_model_str = create_matmul_nbits_onnx_model(config) + + # Assuming CUDA execution provider for this script + if "CUDAExecutionProvider" not in get_available_providers(): + raise RuntimeError("CUDAExecutionProvider is not available.") + + device_id = config.device.index if isinstance(config.device, torch.device) else 0 + provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph) + provider_options["use_tf32"] = int(use_tf32) + # Include CPU as fallback, though performance sensitive tests should target CUDA + providers = [("CUDAExecutionProvider", provider_options), "CPUExecutionProvider"] + + ort_session = InferenceSession(onnx_model_str, session_options, providers=providers) + return ort_session + + +def create_session( + config: MatMulNBitsConfig, session_options: SessionOptions = None, use_tf32: bool = False +) -> CudaSession: + """ + Creates a CudaSession with pre-allocated buffers. + + Args: + config (MatMulNBitsConfig): The configuration for the session. + session_options (SessionOptions, optional): ONNX Runtime session options. Defaults to None. + use_tf32 (bool, optional): Whether to enable TF32 mode on CUDA. Defaults to False. + + Returns: + CudaSession: The created CudaSession. + """ + ort_session = create_ort_session(config, session_options, use_tf32=use_tf32) + cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph) + shape_dict = config.shape_dict() + cuda_session.allocate_buffers(shape_dict) + return cuda_session + + +def measure_latency(cuda_session: CudaSession, input_dict: dict[str, torch.Tensor]) -> float: + """ + Measures the inference latency of a single run using CUDA events. + + Args: + cuda_session (CudaSession): The CudaSession to benchmark. + input_dict (dict[str, torch.Tensor]): The input data for inference. + + Returns: + float: The latency in seconds. + """ + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Synchronize before starting the timed event + torch.cuda.synchronize() + start_event.record() + + cuda_session.infer(input_dict, synchronize=False) # Infer without synchronizing inside + + end_event.record() + # Synchronize after the timed event to get accurate duration + torch.cuda.synchronize() + + latency_ms = start_event.elapsed_time(end_event) # Latency in milliseconds + return latency_ms / 1000.0 # Return latency in seconds + + +def flops(m: int, n: int, k: int) -> int: + """ + Calculates the number of floating-point operations (FLOPs) for a MatMul (M, K) @ (K, N). + """ + # MatMul (M, K) @ (K, N) performs M*N*K multiplications and M*N*(K-1) additions. + # For simplicity, often approximated as 2 * M * N * K. + return 2 * m * n * k + + +def tflops_per_second(flop: int, time_seconds: float) -> float: + """ + Calculates TFLOPS (Tera Floating-point Operations Per Second). + + Args: + flop (int): The number of FLOPs. + time_seconds (float): The time taken in seconds. + + Returns: + float: The TFLOPS/second, or 0.0 if time is non-positive or NaN. + """ + if time_seconds > 0 and not math.isnan(time_seconds): + return (flop / time_seconds) / 1e12 + return 0.0 + + +def get_test_configs(args: argparse.Namespace) -> list[tuple]: + """ + Generates a list of test configurations (m, n, k, block_size, bits). + + Args: + args (argparse.Namespace): Parsed command-line arguments. + + Returns: + list[tuple]: A list of tuples, each representing a configuration (m, n, k, block_size, bits). + """ + if args.phi4: + configs = [] + # Predefined configurations inspired by large language models. + phi_weight_shapes = [ + # (N, K) of MatMul weights in phi4-mini model. + (5120, 3072), + (8192, 3072), + (3072, 8192), + (200064, 3072), + ] + + for bits in [4, 8]: + for m in [1, 256, 1024]: + for block_size in [32, 128]: + for n, k in phi_weight_shapes: + if k % block_size == 0: + configs.append((m, n, k, block_size, bits)) + + configs = sorted(configs) + + else: + # Single configuration from command line arguments + configs = [ + ( + args.m, + args.n, + args.k, + args.block_size, + args.bits, + ), + ] + + return configs + + +def get_compute_capability() -> str: + """ + Gets the CUDA compute capability of the current device. + + Returns: + str: The compute capability in 'major.minor' format, or 'N/A' if CUDA is not available. + """ + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + return f"{major}.{minor}" + return "N/A" + + +def run_tflops_test( + csv_writer: csv.DictWriter, + args: argparse.Namespace, +): + """ + Runs the TFLOPS benchmark for the specified configurations. + + Args: + csv_writer (csv.DictWriter): CSV writer object to write results. + args (argparse.Namespace): Parsed command-line arguments. + """ + assert torch.cuda.is_available() + assert "CUDAExecutionProvider" in get_available_providers() + + enable_cuda_graph: bool = not args.disable_cuda_graph + intra_op_num_threads: int = args.intra_op_num_threads + repeats: int = args.repeats + num_warmup_runs: int = args.warmup_runs + + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + + configs = get_test_configs(args) + + # Print header to console + print("-" * 120) + print( + f"Benchmarking MatMulNBits on {torch.cuda.get_device_name(device_id)} (Compute Capability: {get_compute_capability()})" + ) + print("-" * 120) + # Updated header format to match CSV columns and improve readability + print( + f"{'CUDA Graph':<12} | {'M':<8} | {'N':<8} | {'K':<8} | {'Bits':<6} | {'Block Size':<10} | {'Threads':<8} | {'Latency (us)':<15} | {'StdDev (us)':<12} | {'TFLOPS':<10}" + ) + print("-" * 120) + + for m, n, k, block_size, bits in configs: + config_str = f"(m={m}, n={n}, k={k}, block_size={block_size}, bits={bits})" + try: + config = MatMulNBitsConfig( + bits=bits, + m=m, + n=n, + k=k, + block_size=block_size, + device=device, + dtype=torch.float16, # Assuming float16 for CUDA performance tests + has_zero_point=True, # Assuming zero point for MatMulNBits + has_bias=False, # Not including bias in these benchmarks by default + enable_cuda_graph=enable_cuda_graph, + ) + + sess_options = SessionOptions() + sess_options.intra_op_num_threads = intra_op_num_threads + session = create_session(config, sess_options, use_tf32=args.use_tf32) + input_dict = config.random_inputs() + + # Warm-up runs + for _ in range(num_warmup_runs): + measure_latency(session, input_dict) # Latency is measured, but result is discarded + torch.cuda.synchronize() # Ensure warm-up completes before timing + + # Measure repeats + latency_list_seconds = [] + for _ in range(repeats): + latency = measure_latency(session, input_dict) + latency_list_seconds.append(latency) + + # Explicitly delete session to release GPU memory before processing results + del session + + if not latency_list_seconds: + average_latency_seconds = float("nan") + stddev_latency_seconds = float("nan") + else: + average_latency_seconds = statistics.mean(latency_list_seconds) + stddev_latency_seconds = statistics.stdev(latency_list_seconds) if repeats > 1 else 0.0 + + # compute TFLOPS per second + speed = tflops_per_second( + flops(m, n, k), + average_latency_seconds, + ) + + average_latency_us = average_latency_seconds * 1_000_000 + stddev_latency_us = stddev_latency_seconds * 1_000_000 + + row = { + "use_gpu": True, # Hardcoded to True as this is a CUDA benchmark script + "cuda_graph": enable_cuda_graph, + "m": m, + "n": n, + "k": k, + "bits": bits, + "block_size": block_size, + "intra_op_num_threads": intra_op_num_threads, + "latency_seconds": average_latency_seconds, + "latency_microseconds": average_latency_us, + "latency_stddev_seconds": stddev_latency_seconds, + "latency_stddev_microseconds": stddev_latency_us, + "tflops": speed, + } + csv_writer.writerow(row) + + speed_str = f"{speed:.3f}" if speed is not None and not math.isnan(speed) else "NA" + # Print results to console + print( + f"{enable_cuda_graph!s:<12} | {m:<8} | {n:<8} | {k:<8} | {bits:<6} | {block_size:<10} | {intra_op_num_threads:<8} | {average_latency_us:<15.1f} | {stddev_latency_us:<12.1f} | {speed_str:<10}" + ) + + except ValueError as e: + print(f"Skipping invalid configuration {config_str} - {e}") + # Optionally write a skipped row to CSV? For now, just skip. + continue + except Exception as e: + print(f"Error running benchmark for config {config_str}: {e}") + # Write a row with error info to CSV? Or just skip? Let's just skip for now. + continue + + print("-" * 120) + + +def run_tflops_tests(args): + """ + Sets up the CSV file and runs the benchmark tests. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + """ + csv_filename = "{}{}.csv".format( + args.csv_filename_prefix, + datetime.now().strftime("%Y%m%d-%H%M%S"), + ) + print(f"Writing results to {csv_filename}") + + # Use 'w' mode to create a new file for each run + with open(csv_filename, mode="w", newline="") as csv_file: + column_names = [ + "use_gpu", + "cuda_graph", + "m", + "n", + "k", + "bits", + "block_size", + "intra_op_num_threads", + "latency_seconds", + "latency_microseconds", + "latency_stddev_seconds", + "latency_stddev_microseconds", + "tflops", + ] + csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) + csv_writer.writeheader() + + # The script is specifically for CUDA now + run_tflops_test(csv_writer, args) + + +def _parse_arguments(): + """ + Parses command-line arguments. + """ + parser = argparse.ArgumentParser( + description="Benchmark MatMulNBits performance for ONNX Runtime CUDAExecutionProvider. " + "Supports both single configurations and predefined Phi-like shapes." + ) + + parser.add_argument( + "--disable_cuda_graph", + action="store_true", + help="Disable CUDA graph capture in ONNX Runtime.", + ) + + parser.add_argument( + "--intra_op_num_threads", + type=int, + choices=[0, 1, 2, 4, 8, 16], # Common thread counts, 0 means default. + default=0, + help="intra_op_num_threads for ONNX Runtime session options. 0 means default.", + ) + + # Arguments for a single configuration + parser.add_argument( + "--m", + type=int, + default=1, + help="The M dimension of the MatMul operation (batch size). Used when --phi4 is not set.", + ) + + parser.add_argument( + "--n", + type=int, + default=200064, # This default seems unusually large, but kept from original. + help="The N dimension of the MatMul operation (output features). Used when --phi4 is not set.", + ) + + parser.add_argument( + "--k", + type=int, + default=3072, + help="The K dimension of the MatMul operation (input features). Used when --phi4 is not set.", + ) + + parser.add_argument( + "--block_size", + type=int, + default=32, + help="The block size used for quantization along the K dimension. Used when --phi4 is not set.", + ) + + parser.add_argument( + "--bits", + type=int, + choices=[4, 8], + default=8, + help="Number of bits for quantization (4 or 8). Used when --phi4 is not set.", + ) + + parser.add_argument( + "-r", + "--repeats", + type=int, + default=10000, # Default repeats for measurement + help="Number of repeats for performance measurement of each configuration.", + ) + + parser.add_argument( + "--warmup_runs", + type=int, + default=10, # Default warmup runs + help="Number of warm-up runs before performance measurement for each configuration.", + ) + + parser.add_argument( + "--phi4", + action="store_true", + help="Run a predefined set of configurations based on Phi4-mini model shapes, overriding --m, --n, --k, --block_size, --bits.", + ) + + parser.add_argument( + "--csv_filename_prefix", + type=str, + default="benchmark_matmulnbits_cuda_", + help="Prefix for the output CSV filename.", + ) + + parser.add_argument( + "--use_tf32", + action="store_true", + help="Enable TF32 mode on CUDA. May affect precision and performance on compatible GPUs (Ampere+).", + ) + + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = _parse_arguments() + print(f"Parsed arguments: {args}") + + # Check for CUDA availability early + if not torch.cuda.is_available(): + print("Error: CUDA is not available. This script requires a CUDA-enabled GPU.") + exit(1) + + if "CUDAExecutionProvider" not in get_available_providers(): + print("Error: CUDAExecutionProvider is not available in your ONNX Runtime installation.") + print("Please ensure you have installed the onnxruntime-gpu package (`pip install onnxruntime-gpu`).") + exit(1) + + # Check if k is divisible by block_size for the single config case + if not args.phi4 and args.k % args.block_size != 0: + print( + f"Error: For the single configuration (--phi4 not set), K ({args.k}) must be divisible by block_size ({args.block_size})." + ) + exit(1) + + run_tflops_tests(args) diff --git a/onnxruntime/test/python/quantization/test_op_matmul_8bits.py b/onnxruntime/test/python/quantization/test_op_matmul_8bits.py new file mode 100644 index 0000000000000..6354b7c5fcf0d --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_matmul_8bits.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import tempfile +import unittest +from pathlib import Path + +import numpy as np +import onnx +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type + +from onnxruntime import get_available_providers +from onnxruntime.quantization import quant_utils + + +@unittest.skipIf( + "CUDAExecutionProvider" not in get_available_providers(), reason="CUDA is not available, skipping tests." +) +class TestOpMatMul8Bits(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmul8bits.") + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def fill_weight_data(self, shape: tuple[int, ...]) -> np.ndarray: + return np.random.normal(0, 0.01, size=shape).astype(np.float32) + + def input_feeds( + self, + n: int, + name2shape: dict[str, int | tuple[int, ...]], + low: int = -1, + high: int = 2, + dtype: type = np.float32, + ) -> TestDataFeeds: + input_data_list = [] + for _i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(low, high, shape).astype(dtype)}) + input_data_list.extend([inputs]) + dr = TestDataFeeds(input_data_list) + return dr + + def construct_model_matmul(self, output_model_path: str, k: int = 32, n: int = 64) -> None: + """Create a simple onnx model with one MatMul node like (input) --> MatMul --> (output).""" + input_name = "input" + output_name = "output" + initializers = [] + + def make_matmul( + input_name, weight_shape: int | tuple[int, ...], weight_name: str, output_name: str, node_name: str + ): + weight_data = self.fill_weight_data(weight_shape) + initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) + return onnx.helper.make_node( + "MatMul", + [input_name, weight_name], + [output_name], + node_name, + ) + + in_features = k + out_features = n + # make MatMul node + matmul_node = make_matmul( + input_name, + [in_features, out_features], + "linear1.weight", + output_name, + "MatMul_0", + ) + + # make graph + input_tensor = onnx.helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, [-1, in_features]) + output_tensor = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, [-1, out_features]) + graph_name = "matmul_8bits_test" + graph = onnx.helper.make_graph( + [matmul_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) + # blocked quantization requires DQ op set >= 21 + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 21)]) + model.ir_version = 10 # use stable onnx ir version + + onnx.save(model, output_model_path) + + def quant_test( + self, + model_fp32_path: str, + data_reader: TestDataFeeds, + block_size: int, + is_symmetric: bool, + quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator, + op_types_to_quantize: tuple[str, ...] = ("MatMul",), + quant_axes: tuple[tuple[str, int], ...] = (("MatMul", 0), ("Gather", 1)), + rtol: float = 0.01, + atol: float = 0.05, + config: str = "default", + suffix: str = "", + ): + use_qdq = quant_format == quant_utils.QuantFormat.QDQ + name_prefix = "QDQ" if use_qdq else "QOperator" + model_int8_path = str( + Path(self._tmp_model_dir.name) + .joinpath(f"{name_prefix}_bs{block_size}_{is_symmetric}{suffix}.onnx") + .absolute() + ) + + # Quantize fp32 model to int8 model + from onnxruntime.quantization import matmul_nbits_quantizer + + model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) + + assert config in ["default", "hqq"] + if config == "default": + quant_config = matmul_nbits_quantizer.DefaultWeightOnlyQuantConfig( + block_size=block_size, + is_symmetric=is_symmetric, + quant_format=quant_format, + op_types_to_quantize=op_types_to_quantize, + quant_axes=quant_axes, + bits=8, + ) + else: + quant_config = matmul_nbits_quantizer.HQQWeightOnlyQuantConfig( + block_size=block_size, + bits=8, + quant_format=quant_format, + op_types_to_quantize=op_types_to_quantize, + quant_axes=quant_axes, + ) + + quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(model, algo_config=quant_config) + quant.process() + quant.model.save_model_to_file(model_int8_path, False) + + if "Gather" in op_types_to_quantize: + quant_nodes = {"GatherBlockQuantized": 1} + else: + quant_nodes = {"DequantizeLinear": 1, "MatMul": 1} if use_qdq else {"MatMulNBits": 1} + check_op_type_count(self, model_int8_path, **quant_nodes) + + if use_qdq: + dq_qtype = onnx.TensorProto.INT8 if is_symmetric else onnx.TensorProto.UINT8 + dqnode_io_qtypes = ( + { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ] + } + if is_symmetric + else { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ["i", 2, dq_qtype], + ] + } + ) + check_qtype_by_node_type(self, model_int8_path, dqnode_io_qtypes) + for op in quant.model.opset_import(): + if op.domain in [None, "", "ai.onnx"] and op.version < 21: + self.fail(f"In QDQ format {op.domain} opset should be >= 21") + + data_reader.rewind() + + try: + check_model_correctness( + self, + model_fp32_path, + model_int8_path, + data_reader.get_next(), + rtol, + atol, + providers=["CUDAExecutionProvider"], + ) + except Exception as exception: + if "8b quantization not yet supported on this hardware platform!" in exception.args[0]: + # Currently we don't have int8 quantization support on all platforms, has to tolerate this exception + pass + else: + raise exception + + def test_quantize_matmul_8bits(self): + np.random.seed(13) + for k in [32, 40, 256, 512, 512, 1024, 1040]: + for n in [8, 256]: + model_fp32_path = str( + Path(self._tmp_model_dir.name).joinpath(f"matmul_fp32_k_{k}_n_{n}.onnx").absolute() + ) + self.construct_model_matmul(model_fp32_path, k=k, n=n) + for m in [1, 2]: + data_reader = self.input_feeds(m, {"input": (m, k)}) + for config in ["default", "hqq"]: + for block_size in [16, 128, 256]: + if block_size <= k: + self.quant_test( + model_fp32_path, + data_reader, + block_size, + True, + atol=0.01, + rtol=0.01, + config=config, + suffix=f"_m_{m}_n_{n}_k_{k}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index c5cf8a07f557d..a2cce5b954255 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -86,9 +86,9 @@ def run_beam_search(self, extra_arguments: str, sentences=None, append_arguments arguments = extra_arguments.split() if is_greedy: - arguments.extend("--num_beams 1 --num_return_sequences 1".split()) + arguments.extend(["--num_beams", "1", "--num_return_sequences", "1"]) else: - arguments.extend("--output_sequences_score".split()) + arguments.extend(["--output_sequences_score"]) # Test CPU result = run(arguments, sentences=self.sentences if sentences is None else sentences) diff --git a/onnxruntime/test/shared_lib/test_session_options.cc b/onnxruntime/test/shared_lib/test_session_options.cc index 8e10ce7ffacc0..3fbb294e1af49 100644 --- a/onnxruntime/test/shared_lib/test_session_options.cc +++ b/onnxruntime/test/shared_lib/test_session_options.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/common/common.h" +#include "core/graph/constants.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "gmock/gmock.h" @@ -34,3 +35,103 @@ TEST(CApiTest, session_options_oversized_affinity_string) { } #endif + +#if defined(USE_OPENVINO_PROVIDER_INTERFACE) +// Test that loading OpenVINO EP when only the interface is built (but not the full EP) fails. +TEST(CApiTest, session_options_provider_interface_fail_add_openvino) { + const OrtApi& api = Ort::GetApi(); + Ort::SessionOptions session_options; + + Ort::Status status = Ort::Status{api.SessionOptionsAppendExecutionProvider(session_options, + kOpenVINOExecutionProvider, + nullptr, nullptr, 0)}; + ASSERT_TRUE(!status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_THAT(status.GetErrorMessage(), testing::HasSubstr("Failed to load")); +} +#endif // defined(USE_OPENVINO_PROVIDER_INTERFACE) + +#if defined(USE_CUDA_PROVIDER_INTERFACE) +// Test that loading CUDA EP when only the interface is built (but not the full EP) fails. +TEST(CApiTest, session_options_provider_interface_fail_add_cuda) { + const OrtApi& api = Ort::GetApi(); + Ort::SessionOptions session_options; + + OrtCUDAProviderOptionsV2* cuda_options = nullptr; + Ort::Status status1 = Ort::Status{api.CreateCUDAProviderOptions(&cuda_options)}; + ASSERT_TRUE(status1.IsOK()); + + Ort::Status status2 = Ort::Status{api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, + cuda_options)}; + ASSERT_FALSE(status2.IsOK()); + EXPECT_EQ(status2.GetErrorCode(), ORT_FAIL); + EXPECT_THAT(status2.GetErrorMessage(), testing::HasSubstr("Failed to load")); + + api.ReleaseCUDAProviderOptions(cuda_options); +} +#endif // defined(USE_CUDA_PROVIDER_INTERFACE) + +#if defined(USE_NV_PROVIDER_INTERFACE) +// Test that loading NV EP when only the interface is built (but not the full EP) fails. +TEST(CApiTest, session_options_provider_interface_fail_add_nv) { + const OrtApi& api = Ort::GetApi(); + Ort::SessionOptions session_options; + + Ort::Status status = Ort::Status{api.SessionOptionsAppendExecutionProvider(session_options, + kNvTensorRTRTXExecutionProvider, + nullptr, nullptr, 0)}; + ASSERT_TRUE(!status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_THAT(status.GetErrorMessage(), testing::HasSubstr("Failed to load")); +} +#endif // defined(USE_OPENVINO_PROVIDER_INTERFACE) + +#if defined(USE_TENSORRT_PROVIDER_INTERFACE) +// Test that loading TensorRT EP when only the interface is built (but not the full EP) fails. +TEST(CApiTest, session_options_provider_interface_fail_add_tensorrt) { + const OrtApi& api = Ort::GetApi(); + Ort::SessionOptions session_options; + + OrtTensorRTProviderOptionsV2* trt_options = nullptr; + Ort::Status status1 = Ort::Status{api.CreateTensorRTProviderOptions(&trt_options)}; + ASSERT_TRUE(status1.IsOK()); + + Ort::Status status2 = Ort::Status{api.SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, + trt_options)}; + ASSERT_FALSE(status2.IsOK()); + EXPECT_EQ(status2.GetErrorCode(), ORT_FAIL); + EXPECT_THAT(status2.GetErrorMessage(), testing::HasSubstr("Failed to load")); + + api.ReleaseTensorRTProviderOptions(trt_options); +} +#endif // defined(USE_TENSORRT_PROVIDER_INTERFACE) + +#if defined(USE_VITISAI_PROVIDER_INTERFACE) +// Test that loading VitisAI EP when only the interface is built (but not the full EP) fails. +TEST(CApiTest, session_options_provider_interface_fail_vitisai) { + const OrtApi& api = Ort::GetApi(); + Ort::SessionOptions session_options; + + Ort::Status status = Ort::Status{api.SessionOptionsAppendExecutionProvider(session_options, + kVitisAIExecutionProvider, + nullptr, nullptr, 0)}; + ASSERT_TRUE(!status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_THAT(status.GetErrorMessage(), testing::HasSubstr("Failed to load")); +} +#endif // defined(USE_VITISAI_PROVIDER_INTERFACE) + +#if defined(USE_QNN_PROVIDER_INTERFACE) +// Test that loading QNN EP when only the interface is built (but not the full EP) fails. +TEST(CApiTest, session_options_provider_interface_fail_qnn) { + const OrtApi& api = Ort::GetApi(); + Ort::SessionOptions session_options; + + Ort::Status status = Ort::Status{api.SessionOptionsAppendExecutionProvider(session_options, + kQnnExecutionProvider, + nullptr, nullptr, 0)}; + ASSERT_TRUE(!status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_THAT(status.GetErrorMessage(), testing::HasSubstr("Failed to load")); +} +#endif // defined(USE_QNN_PROVIDER_INTERFACE) diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 127f4fd445f4e..341d48342dce8 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -31,12 +31,9 @@ "current_failing_tests": [ "^test_adagrad", "^test_adagrad_multiple", - "^test_batchnorm_epsilon_old", "^test_batchnorm_epsilon_training_mode", - "^test_batchnorm_example_old", "^test_batchnorm_example_training_mode", "^test_col2im_pads", // still one wrong value coming from the backtest example - "^test_gathernd_example_int32_batch_dim1", "^test_max_int16", "^test_max_int8", "^test_max_uint16", @@ -52,8 +49,6 @@ "^test_pow_types_float32_uint64", "^test_gradient_of_add_and_mul", "^test_gradient_of_add", - "^test_batchnorm_example_training_mode", - "^test_batchnorm_epsilon_training_mode", "^test_MaxPool2d_stride_padding_dilation_cpu", // result approximation error; need to be updated in ONNX "^test_maxunpool_export_with_output_shape", // result mismatch "^test_resize_downsample_scales_cubic_align_corners", // results mismatch with onnx tests diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py b/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py index afd8259471342..46dd9aa417184 100644 --- a/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py @@ -21,10 +21,10 @@ def save(model_path, nodes, inputs, outputs, initializers, opsets=opsets): onnx.save(model, model_path) -def gen(model_path, use_transpose_matmul, scale_input_0, scale_input_1, scale_output): - matmul_op = "FusedMatMul" if use_transpose_matmul else "MatMul" - matmul_domain = "com.microsoft" if use_transpose_matmul else "" - matmul_attrs = {"alpha": scale_value} if use_transpose_matmul else {} +def gen(model_path: str, use_fused_matmul: bool, scale_input_0: bool, scale_input_1: bool, scale_output: bool): + matmul_op = "FusedMatMul" if use_fused_matmul else "MatMul" + matmul_domain = "com.microsoft" if use_fused_matmul else "" + matmul_attrs = {"alpha": scale_value} if use_fused_matmul else {} nodes = [] @@ -85,7 +85,7 @@ def gen(model_path, use_transpose_matmul, scale_input_0, scale_input_1, scale_ou UNFUSABLE_SCALE_NOT_CONSTANT = 2 -def gen_unfusable(model_path, unfusable_type): +def gen_unfusable(model_path: str, unfusable_type: int): matmul_op = "MatMul" if unfusable_type == UNFUSABLE_DIV_NOT_SCALE: @@ -122,7 +122,33 @@ def gen_unfusable(model_path, unfusable_type): gen_unfusable("matmul_scale_unfusable_scale_not_constant.onnx", UNFUSABLE_SCALE_NOT_CONSTANT) -def gen_reused_input_scale(model_path): +def gen_unfusable_scale_broadcast_changes_shape(model_path: str): + matmul_op = "MatMul" + scale_node = helper.make_node("Mul", ["scale_with_leading_dims", "input_0"], ["scaled_input_0"], "scale input_0") + + nodes = [ + scale_node, + helper.make_node(matmul_op, ["scaled_input_0", "input_1"], ["output"], matmul_op), + ] + + initializers = [ + helper.make_tensor("scale_with_leading_dims", TensorProto.FLOAT, [1, 1, 1], [scale_value]), + ] + + inputs = [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [1, "K", "N"]), + ] + + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, "M", "N"])] + + save(model_path, nodes, inputs, outputs, initializers) + + +gen_unfusable_scale_broadcast_changes_shape("matmul_scale_unfusable_scale_broadcasting_changes_shape.onnx") + + +def gen_reused_input_scale(model_path: str): matmul_op = "MatMul" nodes = [ @@ -160,7 +186,7 @@ def gen_reused_input_scale(model_path): gen_reused_input_scale("matmul_scale_reused_input_scale.onnx") -def gen_int32(model_path): +def gen_int32(model_path: str): matmul_op = "MatMul" nodes = [ @@ -190,7 +216,7 @@ def gen_int32(model_path): gen_int32("matmul_scale_int32.onnx") -def gen_scale_input(model_path): +def gen_scale_input(model_path: str): nodes = [ helper.make_node("Mul", ["input_0", "scale"], ["scaled_input_0"], "scale input_0"), helper.make_node( diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0.onnx index 4b640d6cb1142..66a483b1e7e90 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1.onnx index 9b8d9dc54a0ae..bbd3acc227775 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1_out.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1_out.onnx index 640a8cca5f006..8e7e198442737 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1_out.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_in0_in1_out.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_int32.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_int32.onnx index 0048a60d738f2..283329acb4b74 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_scale_int32.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_int32.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_reused_input_scale.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_reused_input_scale.onnx index 661e5d7726da0..2bc23ea251db7 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_scale_reused_input_scale.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_reused_input_scale.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_transposescalematmul_in0_in1_out.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_transposescalematmul_in0_in1_out.onnx index 2abc049ba6742..00d5ed889a298 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_scale_transposescalematmul_in0_in1_out.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_transposescalematmul_in0_in1_out.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_div_not_scale.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_div_not_scale.onnx index 9290e5589e4f9..7a9bebe49a7f4 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_div_not_scale.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_div_not_scale.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_broadcasting_changes_shape.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_broadcasting_changes_shape.onnx new file mode 100644 index 0000000000000..34ff2d3a53c5f Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_broadcasting_changes_shape.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_constant.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_constant.onnx index 52c7caf674883..e519ff3bf0676 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_constant.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_constant.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_scalar.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_scalar.onnx index 869c0bdd14ce1..7f966cf0d12f9 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_scalar.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_unfusable_scale_not_scalar.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx index eb59332fe1485..d92180ac5317b 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx differ diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py index 2ae3c98137cbd..ed7bfeb3c96ad 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py @@ -7,6 +7,7 @@ from typing import Optional, TypeVar import torch +from typing_extensions import Self T = TypeVar("T", bound="torch.nn.Module") @@ -44,13 +45,13 @@ def forward(self): def _apply(self, fn): raise NotImplementedError(f"_apply is not implemented for {type(self)}.") - def apply(self: T, fn: Callable[[T], None]) -> T: + def apply(self, fn: Callable[[Self], None]) -> Self: raise NotImplementedError(f"apply is not implemented for {type(self)}.") def is_training(self): raise NotImplementedError(f"is_training is not implemented for {type(self)}.") - def train(self: T, mode: bool = True) -> T: + def train(self, mode: bool = True) -> Self: raise NotImplementedError(f"train is not implemented for {type(self)}.") def state_dict(self, destination=None, prefix="", keep_vars=False): diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py index 2ed346fe0bfa6..36433ec5017c9 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py @@ -8,6 +8,7 @@ from typing import Optional, TypeVar import torch +from typing_extensions import Self from . import _io, _utils from ._fallback import ORTModuleTorchModelException, _FallbackManager, wrap_exception @@ -43,7 +44,7 @@ def _apply(self, fn): self._flattened_module._apply(fn) return self - def apply(self: T, fn: Callable[[T], None]) -> T: + def apply(self, fn: Callable[[Self], None]) -> Self: """Override original method to delegate execution to the flattened PyTorch user module""" # Delegation must happen to _flattened_module since methods depend on @@ -54,7 +55,7 @@ def apply(self: T, fn: Callable[[T], None]) -> T: def is_training(self): return self._flattened_module.training and torch.is_grad_enabled() - def train(self: T, mode: bool = True) -> T: + def train(self, mode: bool = True) -> Self: """Override original method to delegate execution to the flattened PyTorch user module""" # Delegate the task to _module.flattened_module.train which will recursively diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py index 2c38e98cc8657..74792552937dd 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py @@ -7,6 +7,7 @@ from typing import Optional, TypeVar import torch +from typing_extensions import Self from ._torch_module_interface import TorchModuleInterface @@ -22,14 +23,14 @@ def _apply(self, fn): self._original_module._apply(fn) return self - def apply(self: T, fn: Callable[[T], None]) -> T: + def apply(self, fn: Callable[[Self], None]) -> Self: self._original_module.apply(fn) return self def is_training(self): return self._original_module.training and torch.is_grad_enabled() - def train(self: T, mode: bool = True) -> T: + def train(self, mode: bool = True) -> Self: self._original_module.train(mode) return self diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index a7942eea5be26..dd56e6986bd65 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -21,6 +21,7 @@ from typing import TypeVar from collections import OrderedDict from collections.abc import Iterator, Callable +from typing_extensions import Self # Needed to override PyTorch methods T = TypeVar("T", bound="torch.nn.Module") @@ -187,7 +188,7 @@ def _apply(self, fn): self._torch_module._apply(fn) return self - def apply(self: T, fn: Callable[[torch.nn.Module], None]) -> T: + def apply(self, fn: Callable[[torch.nn.Module], None]) -> Self: """Override :meth:`~torch.nn.Module.apply` to delegate execution to ONNX Runtime""" self._torch_module.apply(fn) @@ -196,7 +197,7 @@ def apply(self: T, fn: Callable[[torch.nn.Module], None]) -> T: def _is_training(self): return self._torch_module.is_training() - def train(self: T, mode: bool = True) -> T: + def train(self, mode: bool = True) -> Self: """Override :meth:`~torch.nn.Module.train` to delegate execution to ONNX Runtime""" self.training = mode diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py index bb0952dea56b7..b26259b8abf94 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py @@ -57,7 +57,7 @@ def build_torch_cpp_extensions(): ) # Docker build don't have CUDA support, but Torch C++ extensions with CUDA may be forced - force_cuda = bool(os.environ.get("ONNXRUNTIME_FORCE_CUDA", False)) + force_cuda = bool(os.environ.get("ONNXRUNTIME_FORCE_CUDA", None)) os.chdir(ortmodule.ORTMODULE_TORCH_CPP_DIR) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py index 043c70263d31e..e334604a6c0a9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py @@ -27,9 +27,9 @@ class ReduceWithMarkDirtyFunction(torch.autograd.Function): def forward(ctx, arg): def reduce(buffer): # All-reduce. - address_for_torch_tensor = int(id(buffer)) + address_for_torch_tensor = id(buffer) torch.distributed.all_reduce(buffer) - address_for_output_torch_tensor = int(id(buffer)) + address_for_output_torch_tensor = id(buffer) if address_for_output_torch_tensor != address_for_torch_tensor: raise ValueError("The output torch tensor should reuse the input torch tensor, but actually not.") return buffer diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index 0d5825fb3140e..612bb700c4db8 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -319,7 +319,7 @@ def flat_accuracy(preds, labels): def format_time(elapsed): """Takes a time in seconds and returns a string hh:mm:ss""" # Round to the nearest second. - elapsed_rounded = int(round(elapsed)) + elapsed_rounded = int(round(elapsed)) # noqa: RUF046 # Format as hh:mm:ss return str(datetime.timedelta(seconds=elapsed_rounded)) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py index 50f411c02a5b5..59b0ca637a4bb 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py @@ -322,7 +322,7 @@ def flat_accuracy(preds, labels): def format_time(elapsed): """Takes a time in seconds and returns a string hh:mm:ss""" # Round to the nearest second. - elapsed_rounded = int(round(elapsed)) + elapsed_rounded = int(round(elapsed)) # noqa: RUF046 # Format as hh:mm:ss return str(datetime.timedelta(seconds=elapsed_rounded)) diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 96019680179d4..02408f6ed17e8 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -3,6 +3,6 @@ lintrunner==0.12.7 lintrunner-adapters==0.12.4 # RUFF -ruff==0.9.5 +ruff==0.11.6 # CLANGFORMAT clang-format==19.1.7 diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index d9ed1deb61e2a..03b51790e0ef6 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -444,6 +444,7 @@ def generate_build_tree( # interface variables are used only for building onnxruntime/onnxruntime_shared.dll but not EPs "-Donnxruntime_USE_TENSORRT_INTERFACE=" + ("ON" if args.enable_generic_interface else "OFF"), "-Donnxruntime_USE_CUDA_INTERFACE=" + ("ON" if args.enable_generic_interface else "OFF"), + "-Donnxruntime_USE_NV_INTERFACE=" + ("ON" if args.enable_generic_interface else "OFF"), "-Donnxruntime_USE_OPENVINO_INTERFACE=" + ("ON" if args.enable_generic_interface else "OFF"), "-Donnxruntime_USE_VITISAI_INTERFACE=" + ("ON" if args.enable_generic_interface else "OFF"), "-Donnxruntime_USE_QNN_INTERFACE=" + ("ON" if args.enable_generic_interface else "OFF"), @@ -2188,18 +2189,6 @@ def main(): cmake_extra_defines = normalize_arg_list(args.cmake_extra_defines) - # When this flag is enabled, it is possible ONNXRuntime shared library is build separately, expecting some compatible EP - # shared lib being build in a separate process. So we skip the testing if none of the primary EPs are built with ONNXRuntime - # shared lib - if args.enable_generic_interface and not ( - args.use_nv_tensorrt_rtx - or args.use_tensorrt - or args.use_openvino - or args.use_vitisai - or (args.use_qnn and args.use_qnn != "static_lib") - ): - args.test = False - if args.use_tensorrt or args.use_nv_tensorrt_rtx: args.use_cuda = True diff --git a/tools/ci_build/coverage.py b/tools/ci_build/coverage.py index 48d919aa7358b..9ec7299927bdd 100644 --- a/tools/ci_build/coverage.py +++ b/tools/ci_build/coverage.py @@ -52,7 +52,7 @@ def adb_shell(*args, **kwargs): adb_shell("cd /data/local/tmp && tar -zcf gcda_files.tar.gz *.dir") adb_pull("/data/local/tmp/gcda_files.tar.gz", cwd) os.chdir(cwd) - run_subprocess("tar -zxf gcda_files.tar.gz -C CMakeFiles".split(" ")) + run_subprocess(["tar", "-zxf", "gcda_files.tar.gz", "-C", "CMakeFiles"]) cmd = ["gcovr", "-s", "-r"] cmd.append(os.path.join(source_dir, "onnxruntime")) cmd.extend([".", "-o"]) diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index 122c7651907b0..bf727f0a6aeb0 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -309,14 +309,13 @@ stages: " Repository: onnxruntimeubi8packagestest_torch UseImageCacheContainerRegistry: false - - task: DownloadPackage@1 - displayName: 'Download Meta Llama2 model' inputs: - packageType: upack - feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' - version: 1.0.0 - definition: '6fe0c4ed-9d0e-4d66-94cc-fb6a111d02a5' + packageType: 'upack' + feed: '2692857e-05ef-43b4-ba9c-ccf1c22c437c/8fc55c18-5239-4843-a0df-c7a8b1b36be7' + view: 'bdf2fe57-11ef-4a86-a145-7857a0405755' + definition: '2a6a4112-e5b6-48cb-ab68-ddfb83533702' + version: '1.0.0' downloadPath: $(Agent.TempDirectory)/meta_llama2_7b_hf - script: | @@ -421,16 +420,16 @@ stages: ScriptName: tools/ci_build/get_docker_image.py DockerBuildArgs: '--build-arg BUILD_UID=$( id -u )' Repository: onnxruntimepackagestest_ompffmpeg - - task: DownloadPackage@1 # The model data in artifact is downloaded from openai/whisper-large-v3 in huggingface model hub # In order to save size, removed .git directory and pickled files, and keep the safetensors model files displayName: 'Download Whisper Model' inputs: - packageType: upack - feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' - version: 1.0.0 - definition: 'b583ce7c-1a8f-4099-ae28-5d5f56c478b1' + packageType: 'upack' + feed: '2692857e-05ef-43b4-ba9c-ccf1c22c437c/8fc55c18-5239-4843-a0df-c7a8b1b36be7' + view: 'bdf2fe57-11ef-4a86-a145-7857a0405755' + definition: 'c78cdb16-4022-4baa-84d4-8bccbe935762' + version: '1.0.1' downloadPath: $(Agent.TempDirectory)/whisper_large_v3 - script: | diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 2cb64733f6f6c..cf213c47195c4 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -179,7 +179,7 @@ extends: IsReleaseBuild: ${{ parameters.IsReleaseBuild }} ArtifactName: 'drop-nuget-dml' StageName: 'Windows_CI_GPU_DML_Dev' - BuildCommand: --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --build_nodejs --cmake_generator "Visual Studio 17 2022" --use_vcpkg --use_vcpkg_ms_internal_asset_cache + BuildCommand: --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --enable_generic_interface --build_nodejs --cmake_generator "Visual Studio 17 2022" --use_vcpkg --use_vcpkg_ms_internal_asset_cache BuildArch: 'x64' msbuildArchitecture: 'amd64' EnvSetupScript: 'setup_env.bat' @@ -199,7 +199,7 @@ extends: IsReleaseBuild: ${{ parameters.IsReleaseBuild }} ArtifactName: 'drop-win-dml-x86-zip' StageName: 'Windows_CI_GPU_DML_Dev_x86' - BuildCommand: --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --cmake_generator "Visual Studio 17 2022" --use_vcpkg --use_vcpkg_ms_internal_asset_cache + BuildCommand: --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --enable_generic_interface --cmake_generator "Visual Studio 17 2022" --use_vcpkg --use_vcpkg_ms_internal_asset_cache BuildArch: 'x86' EnvSetupScript: 'setup_env_x86.bat' sln_platform: 'Win32' @@ -220,7 +220,7 @@ extends: IsReleaseBuild: ${{ parameters.IsReleaseBuild }} ArtifactName: 'drop-win-dml-arm64-zip' StageName: 'Windows_CI_GPU_DML_Dev_arm64' - BuildCommand: --build_dir $(Build.BinariesDirectory) --arm64 --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --build_nodejs --cmake_generator "Visual Studio 17 2022" --use_vcpkg --use_vcpkg_ms_internal_asset_cache + BuildCommand: --build_dir $(Build.BinariesDirectory) --arm64 --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --enable_generic_interface --build_nodejs --cmake_generator "Visual Studio 17 2022" --use_vcpkg --use_vcpkg_ms_internal_asset_cache BuildArch: 'x64' EnvSetupScript: 'setup_env.bat' sln_platform: 'arm64' diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index 7b214dbdfae3a..b1a7c92dc3529 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -14,7 +14,7 @@ parameters: default: false - name: PackageName - displayName: What is the package name? + displayName: What is the package name? Override using an environment variable CustomPackageName. type: string default: 'Microsoft.ML.OnnxRuntime.Flamingo' @@ -72,15 +72,19 @@ extends: DoEsrp: true ArtifactName: 'drop-nuget-qnn-arm64' # Add --use_webgpu to enable WebGPU - buildParameter: '--arm64' - buildPlatform: 'ARM64' - buildArch: 'ARM64' StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' build_config: 'RelWithDebInfo' PublishArchive: true + - template: templates/mac-cpu-packaging-pipeline.yml + parameters: + AllowReleasedOpsetOnly: 1 + BuildForAllArchs: true + AdditionalBuildFlags: '--use_webgpu --skip_tests' + DoEsrp: true + - stage: NugetPackaging - dependsOn: [Windows_Packaging_CUDA, OnnxRuntime_QNN_Nuget_Win_Arm64] + dependsOn: [Windows_Packaging_CUDA, OnnxRuntime_QNN_Nuget_Win_Arm64, MacOS_C_API_Package_Publish] jobs: - job: CreateNugetPackage pool: 'Onnxruntime-Win2022-GPU-A10' @@ -98,7 +102,7 @@ extends: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - managed nuget' inputs: - artifactName: 'drop-nuget-qnn-arm64' + artifactName: 'drop-signed-nuget-qnn' targetPath: '$(Build.BinariesDirectory)/managed-nuget' - task: DownloadPipelineArtifact@0 @@ -110,9 +114,56 @@ extends: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - win-arm64' inputs: - artifactName: 'onnxruntime-win-ARM64-qnn' + artifactName: 'onnxruntime-win-arm64x-qnn' targetPath: '$(Build.BinariesDirectory)/win-arm64' + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - osx' + inputs: + artifactName: 'onnxruntime-osx' + targetPath: '$(Build.BinariesDirectory)/osx' + + - task: PowerShell@2 + displayName: 'Create osx directories' + inputs: + targetType: 'inline' + script: | + mkdir -p $(Build.BinariesDirectory)/osx-x64 + Move-Item -Path $(Build.BinariesDirectory)/osx/onnxruntime-osx-x86_64* -Destination $(Build.BinariesDirectory)/osx-x64 + + mkdir -p $(Build.BinariesDirectory)/osx-arm64 + Move-Item -Path $(Build.BinariesDirectory)/osx/onnxruntime-osx-arm64* -Destination $(Build.BinariesDirectory)/osx-arm64 + + - task: PowerShell@2 + displayName: 'List all files downloaded' + inputs: + targetType: 'inline' + script: | + $files = Get-ChildItem $(Build.BinariesDirectory) -Recurse + foreach ($file in $files) { + Write-Host "File: $($file.FullName)" + if ($file -like "*onnxruntime*") { + Write-Host "File onnxruntime: $($file.FullName) - Size: $($file.Length)" + } + } + $dirs = Get-ChildItem $(Build.BinariesDirectory) -Directory + foreach ($dir in $dirs) { + Write-Host "Directory: $($dir.FullName)" + } + $osx_x64_archive = Get-ChildItem -Path $(Build.BinariesDirectory)/osx-x64 -Filter onnxruntime-osx-x86_64* + if ($osx_x64_archive.Count -eq 0) { + Write-Host "No osx-x64 archive found." + } else { + Write-Host "osx-x64 archive found: $($osx_x64_archive[0].FullName)" + } + $osx_arm64_archive = Get-ChildItem -Path $(Build.BinariesDirectory)/osx-arm64 -Filter onnxruntime-osx-arm64* + if ($osx_arm64_archive.Count -eq 0) { + Write-Host "No osx-arm64 archive found." + } else { + Write-Host "osx-arm64 archive found: $($osx_arm64_archive[0].FullName)" + } + workingDirectory: $(Build.BinariesDirectory) + - task: PowerShell@2 displayName: 'Extract Nuget Package Version' inputs: @@ -131,18 +182,40 @@ extends: targetType: 'inline' script: | Expand-Archive -Path $(Build.BinariesDirectory)/win-x64/onnxruntime-win-x64-cuda*.zip -DestinationPath $(Build.BinariesDirectory)/win-x64 - Expand-Archive -Path $(Build.BinariesDirectory)/win-arm64/onnxruntime-win-ARM64-qnn*.zip -DestinationPath $(Build.BinariesDirectory)/win-arm64 + Expand-Archive -Path $(Build.BinariesDirectory)/win-arm64/onnxruntime-win-arm64x-qnn*.zip -DestinationPath $(Build.BinariesDirectory)/win-arm64 + $osx_x64_archive = (Get-ChildItem -Path $(Build.BinariesDirectory)/osx-x64 -Filter onnxruntime-osx-x86_64*)[0].FullName + $osx_arm64_archive = (Get-ChildItem -Path $(Build.BinariesDirectory)/osx-arm64 -Filter onnxruntime-osx-arm64*)[0].FullName + tar -xzf $osx_x64_archive -C $(Build.BinariesDirectory)/osx-x64 2>$null + tar -xzf $osx_arm64_archive -C $(Build.BinariesDirectory)/osx-arm64 2>$null $win_x64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-x64 -Filter onnxruntime-win-x64-cuda*)[0].FullName - $win_arm64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-arm64 -Filter onnxruntime-win-ARM64-qnn*)[0].FullName + $win_arm64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-arm64 -Filter onnxruntime-win-arm64x-qnn*)[0].FullName + $osx_x64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/osx-x64 -Filter onnxruntime-osx-x86_64*)[0].FullName + $osx_arm64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/osx-arm64 -Filter onnxruntime-osx-arm64*)[0].FullName Write-Host "##vso[task.setvariable variable=win_x64;]$win_x64" Write-Host "##vso[task.setvariable variable=win_arm64;]$win_arm64" + Write-Host "##vso[task.setvariable variable=osx_x64;]$osx_x64" + Write-Host "##vso[task.setvariable variable=osx_arm64;]$osx_arm64" + workingDirectory: $(Build.BinariesDirectory) + + - task: PowerShell@2 + displayName: 'Get Package Name' + inputs: + targetType: 'inline' + script: | + if ($env:CustomPackageName) { + Write-Host "##vso[task.setvariable variable=PackageName;]$env:CustomPackageName" + Write-Host "PackageName: $env:CustomPackageName" + } else { + Write-Host "##vso[task.setvariable variable=PackageName;]${{ parameters.PackageName }}" + Write-Host "PackageName: ${{ parameters.PackageName }}" + } workingDirectory: $(Build.BinariesDirectory) - task: PythonScript@0 displayName: 'Generate Nuget Package' inputs: scriptPath: '$(Build.SourcesDirectory)/tools/nuget/generate_nuspec_for_custom_nuget.py' - arguments: '--nuspec_path "$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec" --root_dir "$(Build.SourcesDirectory)" --commit_id "$(Build.SourceVersion)" --win_arm64 "$(win_arm64)" --win_x64 "$(win_x64)" --package_version "$(package_version)" --package_name "${{ parameters.PackageName }}"' + arguments: '--nuspec_path "$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec" --root_dir "$(Build.SourcesDirectory)" --commit_id "$(Build.SourceVersion)" --win_arm64 "$(win_arm64)" --win_x64 "$(win_x64)" --osx_arm64 "$(osx_arm64)" --osx_x64 "$(osx_x64)" --package_version "$(package_version)" --package_name "$(PackageName)"' - task: NuGetCommand@2 displayName: 'Pack Nuget Package' diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml deleted file mode 100644 index 2451d39a1bb06..0000000000000 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml +++ /dev/null @@ -1,336 +0,0 @@ -# This CI has the following steps: -# 1. Build full ORT, install the full ORT python wheel and use it to generate ort format test models -# and include ops config file for step 3. -# 2. Build minimal ORT including all the kernels and disable exceptions. -# This step is build only to safe-guard the --disable_exceptions option. -# 3. Build minimal ORT include only the kernels using the include ops config file from step 1, -# and the models from /onnxruntime/test/testdata/, run UT, and use onnx_test_runner to -# test the ort format models generated in step 1. -# Exceptions are enabled in this step to help debugging in case of CI failure. -# This step builds and tests ORT with (3a) and without (3b) type reduction enabled. -# 4. Build minimal ORT with type reduction from a globally allowed types list. -# This step uses a hard-coded list of types which may not include the types needed by the models -# in /onnxruntime/test/testdata/, so the tests for those models are skipped. -# 5. Build extended minimal ORT and run tests. -# 6. Build with all optional features disabled and no kernels. -# 6a: regular build with python enabled checks that the exclusions don't break code paths in a full build. -# 6b: minimal build with exceptions and python disabled checks that the exclusions don't break code paths in a -# minimal build. -# 6c: extended minimal build with exceptions and python disabled checks that the exclusions don't break code paths -# in an extended minimal build. -# 7. Build extended minimal ORT with NNAPI, with exceptions/RTTI/ml_ops disabled, for Android(arm64-v8a), -# this safe-guards the extended minimal build with NNAPI EP. -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -jobs: -- job: Linux_CPU_Minimal_Build_E2E - timeoutInMinutes: 120 - workspace: - clean: all - pool: onnxruntime-Ubuntu2204-AMD-CPU - variables: - ORT_CACHE_DIR: $(Pipeline.Workspace)/ort_ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - test_data_directory: $(Build.SourcesDirectory)/.test_data - - steps: - - - checkout: self - clean: true - submodules: none - - - template: "templates/use-android-ndk.yml" - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile - Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cpu - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" - Repository: onnxruntimecpubuildcentos8x64_packaging - - - task: CmdLine@2 - displayName: Create test data directory - inputs: - script: | - # Create a folder for all test data - mkdir -p $(test_data_directory) - # create empty config used in some parts - touch $(test_data_directory)/include_no_operators.config - workingDirectory: $(Build.SourcesDirectory) - - - template: templates/linux-build-step-with-cache.yml - parameters: - WithCache: true - Today: $(TODAY) - AdditionalKey: onnxruntime__full - CacheDir: $(ORT_CACHE_DIR) - ChangeEveryCommit: true - BuildStep: - - task: CmdLine@2 - displayName: 1. Build full onnxruntime and generate ORT format test files - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(test_data_directory):/home/onnxruntimedev/.test_data \ - --volume $(ORT_CACHE_DIR):/cache \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - -e CCACHE_DIR=/cache \ - -e ORT_BUILD_WITH_CACHE=1 \ - onnxruntimecpubuildcentos8x64_packaging \ - /bin/bash -c " - set -e -x; - /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh /build/1; \ - ccache -sv; \ - ccache -z;" - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 2. Build minimal onnxruntime [exceptions DISABLED, type reduction DISABLED, training ops ENABLED] - inputs: - script: | - # We will try to build minimal ORT with exceptions disabled and training ops enabled - # Only the building process is verified here, no test will be performed - docker run -e SYSTEM_COLLECTIONURI --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildcentos8x64_packaging \ - bash -c "python3 -m pip install -r /onnxruntime_src/tools/ci_build/requirements/pybind/requirements.txt && python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build/2 --cmake_generator Ninja \ - --config Debug \ - --skip_submodule_sync \ - --build_shared_lib \ - --parallel --use_binskim_compliant_compile_flags \ - --skip_tests \ - --minimal_build \ - --disable_exceptions \ - --enable_training_ops" - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 3a. Build minimal onnxruntime [exceptions ENABLED, type reduction DISABLED, custom ops ENABLED] and run tests - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(test_data_directory):/home/onnxruntimedev/.test_data \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildcentos8x64_packaging \ - /bin/bash /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh \ - --build-directory /build/3a \ - --reduced-ops-config /home/onnxruntimedev/.test_data/required_ops.ort_models.config \ - --enable-custom-ops - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 3b. Build minimal onnxruntime [exceptions ENABLED, type reduction ENABLED] and run tests - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(test_data_directory):/home/onnxruntimedev/.test_data \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildcentos8x64_packaging \ - /bin/bash /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh \ - --build-directory /build/3b \ - --reduced-ops-config /home/onnxruntimedev/.test_data/required_ops_and_types.ort_models.config \ - --enable-type-reduction - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 4. Build minimal onnxruntime [exceptions ENABLED, type reduction ENABLED (globally allowed types)] and run tests - inputs: - script: | - printf "%s\n%s\n" \ - "!globally_allowed_types;bool,float,int8_t,uint8_t" \ - "!no_ops_specified_means_all_ops_are_required" \ - > $(test_data_directory)/globally_allowed_types.config && \ - docker run -e SYSTEM_COLLECTIONURI --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(test_data_directory):/home/onnxruntimedev/.test_data \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildcentos8x64_packaging \ - /bin/bash /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh \ - --build-directory /build/4 \ - --reduced-ops-config /home/onnxruntimedev/.test_data/globally_allowed_types.config \ - --enable-type-reduction \ - --skip-model-tests - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 5. Build extended minimal onnxruntime and run tests - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildcentos8x64_packaging \ - bash -c "python3 -m pip install -r /onnxruntime_src/tools/ci_build/requirements/pybind/requirements.txt && python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build/5 --cmake_generator Ninja \ - --config Debug \ - --skip_submodule_sync \ - --build_shared_lib --use_binskim_compliant_compile_flags \ - --parallel \ - --minimal_build extended" - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 6a. Regular build with python and all optional features disabled. - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(test_data_directory):/home/onnxruntimedev/.test_data \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildcentos8x64_packaging \ - bash -c "python3 -m pip install -r /onnxruntime_src/tools/ci_build/requirements/pybind/requirements.txt && python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build/6a \ - --cmake_generator Ninja \ - --config MinSizeRel \ - --skip_submodule_sync \ - --build_shared_lib \ - --build_wheel \ - --parallel --use_binskim_compliant_compile_flags \ - --skip_tests \ - --disable_ml_ops \ - --disable_types sparsetensor float8 optional \ - --include_ops_by_config /home/onnxruntimedev/.test_data/include_no_operators.config \ - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF" - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 6b. Minimal build with all optional features disabled. - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(test_data_directory):/home/onnxruntimedev/.test_data \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildcentos8x64_packaging \ - bash -c "python3 -m pip install -r /onnxruntime_src/tools/ci_build/requirements/pybind/requirements.txt && python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build/6b \ - --cmake_generator Ninja \ - --config MinSizeRel \ - --skip_submodule_sync \ - --build_shared_lib \ - --parallel --use_binskim_compliant_compile_flags \ - --minimal_build \ - --disable_exceptions \ - --disable_ml_ops \ - --skip_tests \ - --enable_reduced_operator_type_support \ - --disable_types sparsetensor optional float8 \ - --include_ops_by_config /home/onnxruntimedev/.test_data/include_no_operators.config \ - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF" - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 6c. Extended minimal build with all optional features disabled. - inputs: - script: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $(test_data_directory):/home/onnxruntimedev/.test_data \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - onnxruntimecpubuildcentos8x64_packaging \ - bash -c "python3 -m pip install -r /onnxruntime_src/tools/ci_build/requirements/pybind/requirements.txt && python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build/6c \ - --cmake_generator Ninja \ - --config MinSizeRel \ - --skip_submodule_sync \ - --build_shared_lib --use_binskim_compliant_compile_flags \ - --parallel \ - --minimal_build extended \ - --disable_exceptions \ - --disable_ml_ops \ - --skip_tests \ - --enable_reduced_operator_type_support \ - --disable_types sparsetensor optional float8 \ - --include_ops_by_config /home/onnxruntimedev/.test_data/include_no_operators.config \ - --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF" - workingDirectory: $(Build.SourcesDirectory) - - - task: CmdLine@2 - displayName: 7. Extended minimal build with NNAPI EP for Android(arm64-v8a) and skip tests. - inputs: - script: | - NDK_HOME=$(realpath $ANDROID_NDK_HOME) - docker run -e SYSTEM_COLLECTIONURI --rm \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume $ANDROID_HOME:/android_home \ - --volume $NDK_HOME:/ndk_home \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD \ - onnxruntimecpubuildcentos8x64_packaging \ - bash -c "python3 -m pip install -r /onnxruntime_src/tools/ci_build/requirements/pybind/requirements.txt && python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build/7 \ - --cmake_generator Ninja \ - --config MinSizeRel \ - --skip_submodule_sync \ - --parallel --use_binskim_compliant_compile_flags \ - --android \ - --android_sdk_path /android_home \ - --android_ndk_path /ndk_home \ - --android_abi=arm64-v8a \ - --android_api=29 \ - --use_nnapi \ - --minimal_build extended \ - --build_shared_lib \ - --disable_ml_ops \ - --disable_exceptions \ - --skip_tests" - workingDirectory: $(Build.SourcesDirectory) - - - template: templates/explicitly-defined-final-tasks.yml diff --git a/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml b/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml index c6ab33164035c..753395151b620 100644 --- a/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml +++ b/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml @@ -260,12 +260,12 @@ extends: displayName: "Sign Nuget package" inputs: ConnectedServiceName: 'OnnxrunTimeCodeSign_20240611' - AppRegistrationClientId: '53d54d02-978d-4305-8572-583cf6711c4f' - AppRegistrationTenantId: '72f988bf-86f1-41af-91ab-2d7cd011db47' - AuthAKVName: 'buildkeyvault' - AuthCertName: '53d54d02-SSL-AutoRotate' - AuthSignCertName: '53d54d02-978d-4305-8572-583cf6711c4f' - + UseMSIAuthentication: true + AppRegistrationClientId: '62b7cfed-4d25-454f-880e-010dc21455ac' + AppRegistrationTenantId: '975f013f-7f24-47e8-a7d3-abc4752bf346' + EsrpClientId: "53d54d02-978d-4305-8572-583cf6711c4f" + AuthAKVName: 'ortbuildkeyvault' + AuthSignCertName: 'esrpcodesign' FolderPath: $(Build.ArtifactStagingDirectory) Pattern: '*.nupkg' SessionTimeout: 90 diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 34fbe74260ace..722a3162cfed8 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -53,27 +53,10 @@ extends: QnnSdk: ${{ parameters.QnnSdk }} IsReleaseBuild: ${{ parameters.IsReleaseBuild }} DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-x64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' + ArtifactName: 'drop-nuget-qnn-arm64x' + StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64x' build_config: ${{ parameters.build_config }} - - template: templates/qnn-ep-win.yml - parameters: - qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' - QnnSdk: ${{ parameters.QnnSdk }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-arm64' - buildParameter: '--arm64' - buildPlatform: 'ARM64' - buildArch: 'ARM64' - StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64' - build_config: ${{ parameters.build_config }} - - - template: stages/nuget-qnn-packaging-stage.yml - parameters: - DoEsrp: ${{ parameters.DoEsrp }} - - template: templates/publish-nuget-steps.yml parameters: download_artifacts_steps: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml index 093de22566a8b..0f9314dbbedfb 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml @@ -55,6 +55,7 @@ stages: move win-x86\runtimes\win-x86\native\onnxruntime.dll %%~ni\runtimes\win-x86\native\onnxruntime.dll move win-x86\runtimes\win-x86\native\onnxruntime.lib %%~ni\runtimes\win-x86\native\onnxruntime.lib move win-x86\runtimes\win-x86\native\onnxruntime.pdb %%~ni\runtimes\win-x86\native\onnxruntime.pdb + move win-x86\runtimes\win-x86\native\onnxruntime_providers_shared.dll %%~ni\runtimes\win-x86\native\onnxruntime_providers_shared.dll unzip win-dml-arm64.zip -d win-arm64 mkdir %%~ni\runtimes\win-arm64 @@ -63,6 +64,7 @@ stages: move win-arm64\runtimes\win-arm64\native\onnxruntime.dll %%~ni\runtimes\win-arm64\native\onnxruntime.dll move win-arm64\runtimes\win-arm64\native\onnxruntime.lib %%~ni\runtimes\win-arm64\native\onnxruntime.lib move win-arm64\runtimes\win-arm64\native\onnxruntime.pdb %%~ni\runtimes\win-arm64\native\onnxruntime.pdb + move win-arm64\runtimes\win-arm64\native\onnxruntime_providers_shared.dll %%~ni\runtimes\win-arm64\native\onnxruntime_providers_shared.dll pushd %%~ni diff --git a/tools/ci_build/github/azure-pipelines/templates/esrp_nuget.yml b/tools/ci_build/github/azure-pipelines/templates/esrp_nuget.yml index 79cceb7a02511..ffec479474721 100644 --- a/tools/ci_build/github/azure-pipelines/templates/esrp_nuget.yml +++ b/tools/ci_build/github/azure-pipelines/templates/esrp_nuget.yml @@ -9,12 +9,12 @@ steps: displayName: 'ESRP CodeSigning' inputs: ConnectedServiceName: 'OnnxrunTimeCodeSign_20240611' - AppRegistrationClientId: '53d54d02-978d-4305-8572-583cf6711c4f' - AppRegistrationTenantId: '72f988bf-86f1-41af-91ab-2d7cd011db47' - AuthAKVName: 'buildkeyvault' - AuthCertName: '53d54d02-SSL-AutoRotate' - AuthSignCertName: '53d54d02-978d-4305-8572-583cf6711c4f' - + UseMSIAuthentication: true + AppRegistrationClientId: '62b7cfed-4d25-454f-880e-010dc21455ac' + AppRegistrationTenantId: '975f013f-7f24-47e8-a7d3-abc4752bf346' + EsrpClientId: "53d54d02-978d-4305-8572-583cf6711c4f" + AuthAKVName: 'ortbuildkeyvault' + AuthSignCertName: 'esrpcodesign' FolderPath: ${{ parameters.FolderPath }} Pattern: '*.nupkg' SessionTimeout: 90 diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml index d14952e544e5e..df2aff0634819 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml @@ -7,7 +7,7 @@ steps: displayName: 'Get GnuPG signing keys' inputs: #The value below is the name of an ADO service connection. - azureSubscription: 'OnnxrunTimeCodeSign_20240611' + azureSubscription: 'AIInfraBuildOnnxRuntimeOSS' KeyVaultName: 'ort-release' SecretsFilter: 'java-pgp-pwd,java-pgp-key' RunAsPreJob: false diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml index 5681b3568bae1..ef845dc3bf243 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml @@ -6,7 +6,7 @@ steps: - task: AzureKeyVault@2 displayName: 'Get GnuPG signing keys' inputs: - azureSubscription: 'OnnxrunTimeCodeSign_20240611' + azureSubscription: 'AIInfraBuildOnnxRuntimeOSS' KeyVaultName: 'ort-release' SecretsFilter: 'java-pgp-pwd,java-pgp-key' RunAsPreJob: false diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml index 8fcdab437052c..7547b841c7480 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml @@ -107,12 +107,12 @@ stages: ls -l popd displayName: tgz to zip + - template: mac-esrp-dylib.yml parameters: FolderPath: '$(Build.ArtifactStagingDirectory)' - DisplayName: 'ESRP - Sign Mac' - DoEsrp: true Pattern: '*.zip' + - script: | pushd '$(Build.ArtifactStagingDirectory)' find . '*.zip' -exec unzip {} \; @@ -136,5 +136,3 @@ stages: targetPath: '$(Build.ArtifactStagingDirectory)' artifactName: 'onnxruntime-osx' condition: 'succeededOrFailed()' - - diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-esrp-dylib.yml b/tools/ci_build/github/azure-pipelines/templates/mac-esrp-dylib.yml index aeebf2a39c8e0..5e6cd2240feba 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-esrp-dylib.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-esrp-dylib.yml @@ -1,16 +1,8 @@ parameters: -- name: DoEsrp - type: boolean - default: true - - name: FolderPath type: string default: '' -- name: DisplayName - type: string - default: '' - - name: Pattern type: string default: '*.zip' @@ -20,14 +12,14 @@ steps: displayName: 'ESRP CodeSigning' inputs: ConnectedServiceName: 'OnnxrunTimeCodeSign_20240611' - AppRegistrationClientId: '53d54d02-978d-4305-8572-583cf6711c4f' - AppRegistrationTenantId: '72f988bf-86f1-41af-91ab-2d7cd011db47' - AuthAKVName: 'buildkeyvault' - AuthCertName: '53d54d02-SSL-AutoRotate' - AuthSignCertName: '53d54d02-978d-4305-8572-583cf6711c4f' - + UseMSIAuthentication: true + AppRegistrationClientId: '62b7cfed-4d25-454f-880e-010dc21455ac' + AppRegistrationTenantId: '975f013f-7f24-47e8-a7d3-abc4752bf346' + EsrpClientId: "53d54d02-978d-4305-8572-583cf6711c4f" + AuthAKVName: 'ortbuildkeyvault' + AuthSignCertName: 'esrpcodesign' FolderPath: ${{ parameters.FolderPath }} - Pattern: '*.nupkg' + Pattern: ${{ parameters.Pattern }} SessionTimeout: 90 ServiceEndpointUrl: 'https://api.esrp.microsoft.com/api/v2' MaxConcurrency: 25 diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 11560486dfd6c..e4888ffd62df3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -5,10 +5,7 @@ parameters: DoEsrp: false qnn_ep_build_pool_name: 'Onnxruntime-QNNEP-Windows-2022-CPU' ArtifactName: 'drop-nuget-qnn' - buildParameter: '' OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime.QNN' - buildPlatform: 'x64' - buildArch: 'x64' StageName: 'OnnxRuntime_QNN_Nuget_Win_x64' PublishArchive: false @@ -17,17 +14,13 @@ stages: dependsOn: [] jobs: - job: ${{ parameters.StageName }} - timeoutInMinutes: 120 + timeoutInMinutes: 300 pool: name: ${{ parameters.qnn_ep_build_pool_name }} variables: - ${{ if eq(parameters.buildArch, 'ARM64') }}: - targetArchitecture: 'arm64' - ${{ else }}: - targetArchitecture: ${{ parameters.buildArch }} OrtPackageId: ${{ parameters.OrtNugetPackageId }} - commonBuildArgs: '--update --compile_no_warning_as_error --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_binskim_compliant_compile_flags ${{ parameters.buildParameter }} ' + commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_binskim_compliant_compile_flags ' steps: - template: set-version-number-variables-step.yml @@ -42,58 +35,16 @@ stages: QnnSDKVersion: ${{ parameters.QnnSdk }} - task: PythonScript@0 - displayName: 'Generate project' + displayName: 'Build arm64x project - generate the def & lib file for next build' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' - - - task: VSBuild@1 - displayName: 'Build onnxruntime' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnx_test_runner' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnx_test_runner.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnxruntime_perf_test' - inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_perf_test.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true - - - task: VSBuild@1 - displayName: 'Build onnxruntime_test_all (to copy Qnn libs)' + arguments: ' --arm64 --buildasx --build_dir $(Build.BinariesDirectory)\arm64x --use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' + + - task: PythonScript@0 + displayName: 'Build arm64ecx project - the real arm64x' inputs: - solution: '$(Build.BinariesDirectory)\${{ parameters.build_config }}\onnxruntime_test_all.vcxproj' - platform: ${{ parameters.buildPlatform }} - configuration: ${{ parameters.build_config }} - msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\${{ parameters.build_config }}' - createLogFile: true + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: ' --arm64ec --buildasx --build_dir $(Build.BinariesDirectory) --use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' - task: CmdLine@2 displayName: 'Print contents of binaries directory' @@ -112,8 +63,8 @@ stages: - template: c-api-artifacts-package-and-publish-steps-windows.yml parameters: buildConfig: ${{ parameters.build_config }} - artifactName: 'onnxruntime-win-${{ parameters.buildPlatform }}-qnn' - artifactNameNoVersionString: 'onnxruntime-win-${{ parameters.buildPlatform }}-qnn' + artifactName: 'onnxruntime-win-arm64x-qnn' + artifactNameNoVersionString: 'onnxruntime-win-arm64x-qnn' DoEsrp: ${{ parameters.DoEsrp }} - task: MSBuild@1 @@ -147,7 +98,7 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' platform: 'Any CPU' configuration: ${{ parameters.build_config }} - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=$(targetArchitecture)' + msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=arm64x' workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: CopyFiles@2 @@ -164,8 +115,14 @@ stages: Contents: '*.snupkg' TargetFolder: '$(Build.ArtifactStagingDirectory)' + - template: ../templates/esrp_nuget.yml + parameters: + DisplayName: 'ESRP - sign NuGet package' + FolderPath: '$(Build.ArtifactStagingDirectory)' + DoEsrp: ${{ parameters.DoEsrp }} + - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Pipeline x64 NuGet Artifact' + displayName: 'Publish Pipeline Qnn NuGet Artifact' inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' + artifactName: 'drop-signed-nuget-qnn' + targetPath: '$(Build.ArtifactStagingDirectory)' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml b/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml index 86acebc9f7a71..0476bc74349bf 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml @@ -21,11 +21,12 @@ steps: condition: and(succeeded(), eq('${{ parameters.DoEsrp }}', true)) inputs: ConnectedServiceName: 'OnnxrunTimeCodeSign_20240611' - AppRegistrationClientId: '53d54d02-978d-4305-8572-583cf6711c4f' - AppRegistrationTenantId: '72f988bf-86f1-41af-91ab-2d7cd011db47' - AuthAKVName: 'buildkeyvault' - AuthCertName: '53d54d02-SSL-AutoRotate' - AuthSignCertName: '53d54d02-978d-4305-8572-583cf6711c4f' + UseMSIAuthentication: true + AppRegistrationClientId: '62b7cfed-4d25-454f-880e-010dc21455ac' + AppRegistrationTenantId: '975f013f-7f24-47e8-a7d3-abc4752bf346' + EsrpClientId: "53d54d02-978d-4305-8572-583cf6711c4f" + AuthAKVName: 'ortbuildkeyvault' + AuthSignCertName: 'esrpcodesign' signConfigType: inlineSignParams inlineOperation: | [ diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index bc4e0de149b54..93a9909e529f8 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -50,13 +50,16 @@ jobs: matrix: SHARED_LIB: QnnLibKind: 'shared_lib' - qnn_build_args: '--use_qnn' + ExtraQnnBuildArgs: '' STATIC_LIB: QnnLibKind: 'static_lib' - qnn_build_args: '--use_qnn' + ExtraQnnBuildArgs: '' SHARED_LIB_GENERIC_INTERFACE: QnnLibKind: 'shared_lib' - qnn_build_args: '--use_qnn --use_generic_interface' + # Note: Building ORT with generic ep interface which only builds the provider-bridge APIs for + # various EPs, but does not build the actual EPs. We enable --build_wheel for additional code coverage + # because the python bindings also use the USE__PROVIDER_INTERFACE preprocessor macros. + ExtraQnnBuildArgs: '--enable_generic_interface --build_wheel' steps: - script: | @@ -93,7 +96,7 @@ jobs: --build_shared_lib --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_qnn $(QnnLibKind) --qnn_home $(QnnSDKRootDir) - --update --build --parallel + --update --build --parallel $(ExtraQnnBuildArgs) - script: | python $(Build.SourcesDirectory)\tools\ci_build\build.py ^ @@ -103,7 +106,7 @@ jobs: --build_shared_lib ^ --use_qnn $(QnnLibKind) ^ --qnn_home $(QnnSDKRootDir) ^ - --test --enable_onnx_tests + --test --enable_onnx_tests $(ExtraQnnBuildArgs) displayName: 'Run unit tests' - script: | diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile index f9d84e3b0e130..5db0e32e0df8b 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile @@ -8,7 +8,7 @@ ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && python3 -m pip install flatbuffers && rm -rf /tmp/scripts ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile index d94e7562f19d4..6052096877ac5 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -8,7 +8,7 @@ ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && python3 -m pip install flatbuffers && rm -rf /tmp/scripts ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh b/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh index d647cec3ba020..daa6966357188 100755 --- a/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh +++ b/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh @@ -8,17 +8,17 @@ set -x BUILD_DIR=${1:?"usage: $0 "} -python3 -m pip install -r /onnxruntime_src/tools/ci_build/github/linux/python/requirements.txt +python3 -m pip install --user -r tools/ci_build/github/linux/python/requirements.txt # Validate the operator kernel registrations, as the ORT model uses hashes of the kernel registration details # to find kernels. If the hashes from the registration details are incorrect we will produce a model that will break # when the registration is fixed in the future. -python3 /onnxruntime_src/tools/ci_build/op_registration_validator.py +python3 tools/ci_build/op_registration_validator.py # Run a full build of ORT. # We need the ORT python package to generate the ORT format files and the required ops config files. # We do not run tests in this command since those are covered by other CIs. # Both the NNAPI and CoreML EPs are enabled. -python3 /onnxruntime_src/tools/ci_build/build.py \ +python3 tools/ci_build${BUILD_DIR}.py \ --build_dir ${BUILD_DIR} --cmake_generator Ninja \ --config Debug \ --skip_submodule_sync \ @@ -33,38 +33,38 @@ python3 /onnxruntime_src/tools/ci_build/build.py \ python3 -m pip install --user ${BUILD_DIR}/Debug/dist/* # Convert all the E2E ONNX models to ORT format -python3 /onnxruntime_src/tools/python/convert_onnx_models_to_ort.py \ - /onnxruntime_src/onnxruntime/test/testdata/ort_minimal_e2e_test_data +python3 tools/python/convert_onnx_models_to_ort.py \ + onnxruntime/test/testdata/ort_minimal_e2e_test_data # Do it again using the conversion script from the python package to validate that also works python3 -m onnxruntime.tools.convert_onnx_models_to_ort \ - /onnxruntime_src/onnxruntime/test/testdata/ort_minimal_e2e_test_data + onnxruntime/test/testdata/ort_minimal_e2e_test_data # Create configs with just the required ops for ORT format models in testdata # These are used by build_minimal_ort_and_run_tests.sh later in the linux-cpu-minimal-build-ci-pipeline CI # and will include ops for the E2E models we just converted # Config without type reduction -python3 /onnxruntime_src/tools/python/create_reduced_build_config.py --format ORT \ - /onnxruntime_src/onnxruntime/test/testdata \ - /home/onnxruntimedev/.test_data/required_ops.ort_models.config +python3 tools/python/create_reduced_build_config.py --format ORT \ + onnxruntime/test/testdata \ + ${BUILD_DIR}/.test_data/required_ops.ort_models.config # Config with type reduction -python3 /onnxruntime_src/tools/python/create_reduced_build_config.py --format ORT --enable_type_reduction \ - /onnxruntime_src/onnxruntime/test/testdata \ - /home/onnxruntimedev/.test_data/required_ops_and_types.ort_models.config +python3 tools/python/create_reduced_build_config.py --format ORT --enable_type_reduction \ + onnxruntime/test/testdata \ + ${BUILD_DIR}/.test_data/required_ops_and_types.ort_models.config # Append the info for ops involved from inside custom ops. These can't be read from the models as they're # dynamically created at runtime when the kernel is created. -cat /onnxruntime_src/onnxruntime/test/testdata/ort_minimal_e2e_test_data/required_ops.standalone_invoker.config >> \ - /home/onnxruntimedev/.test_data/required_ops.ort_models.config -cat /onnxruntime_src/onnxruntime/test/testdata/ort_minimal_e2e_test_data/required_ops.standalone_invoker.config >> \ - /home/onnxruntimedev/.test_data/required_ops_and_types.ort_models.config +cat onnxruntime/test/testdata/ort_minimal_e2e_test_data/required_ops.standalone_invoker.config >> \ + ${BUILD_DIR}/.test_data/required_ops.ort_models.config +cat onnxruntime/test/testdata/ort_minimal_e2e_test_data/required_ops.standalone_invoker.config >> \ + ${BUILD_DIR}/.test_data/required_ops_and_types.ort_models.config # Test that we can convert an ONNX model with custom ops to ORT format -mkdir /home/onnxruntimedev/.test_data/custom_ops_model -cp /onnxruntime_src/onnxruntime/test/testdata/custom_op_library/*.onnx /home/onnxruntimedev/.test_data/custom_ops_model/ -python3 /onnxruntime_src/tools/python/convert_onnx_models_to_ort.py \ +mkdir ${BUILD_DIR}/.test_data/custom_ops_model +cp onnxruntime/test/testdata/custom_op_library/*.onnx ${BUILD_DIR}/.test_data/custom_ops_model/ +python3 tools/python/convert_onnx_models_to_ort.py \ --custom_op_library ${BUILD_DIR}/Debug/libcustom_op_library.so \ - /home/onnxruntimedev/.test_data/custom_ops_model -rm -rf /home/onnxruntimedev/.test_data/custom_ops_model + ${BUILD_DIR}/.test_data/custom_ops_model +rm -rf ${BUILD_DIR}/.test_data/custom_ops_model diff --git a/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh b/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh deleted file mode 100755 index f5184b20d0a6c..0000000000000 --- a/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh +++ /dev/null @@ -1,92 +0,0 @@ -#!/bin/bash - -# This script will create a minimal build with the required operators for all ORT format models -# in the testdata directory. This includes E2E models generated by build_full_ort_and_create_ort_files.sh. -# The build will run the unit tests for the minimal build, followed by running onnx_test_runner -# for the E2E test cases. - -set -e -set -x - -USAGE_TEXT="Usage: - -b|--build-directory - Specifies the build directory. Required. - -c|--reduced-ops-config - Specifies the reduced Ops configuration file path. Required. - [--enable-type-reduction] - Builds with type reduction enabled. - [--enable-custom-ops] - Builds with custom op support enabled. - [--skip-model-tests] - Does not run the E2E model tests." - -BUILD_DIR= -REDUCED_OPS_CONFIG_FILE= -ENABLE_TYPE_REDUCTION= -MINIMAL_BUILD_ARGS= -SKIP_MODEL_TESTS= - -while [[ $# -gt 0 ]] -do - OPTION_KEY="$1" - case $OPTION_KEY in - -b|--build-directory) - BUILD_DIR="$2" - shift - shift - ;; - -c|--reduced-ops-config) - REDUCED_OPS_CONFIG_FILE="$2" - shift - shift - ;; - --enable-type-reduction) - ENABLE_TYPE_REDUCTION=1 - shift - ;; - --enable-custom-ops) - MINIMAL_BUILD_ARGS="custom_ops" - shift - ;; - --skip-model-tests) - SKIP_MODEL_TESTS=1 - shift - ;; - *) - echo "Invalid option: $1" - echo "$USAGE_TEXT" - exit 1 - ;; - esac -done - -if [[ -z "${BUILD_DIR}" || -z "${REDUCED_OPS_CONFIG_FILE}" ]]; then - echo "Required option was not provided." - echo "$USAGE_TEXT" - exit 1 -fi -python3 -m pip install -r /onnxruntime_src/tools/ci_build/github/linux/python/requirements.txt -# Perform a minimal build with required ops and run ORT minimal build UTs -python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir ${BUILD_DIR} --cmake_generator Ninja \ - --config Debug \ - --skip_submodule_sync \ - --build_shared_lib \ - --parallel --use_binskim_compliant_compile_flags \ - --minimal_build ${MINIMAL_BUILD_ARGS} \ - --disable_ml_ops \ - --include_ops_by_config ${REDUCED_OPS_CONFIG_FILE} \ - ${ENABLE_TYPE_REDUCTION:+"--enable_reduced_operator_type_support"} - -if [[ -z "${SKIP_MODEL_TESTS}" ]]; then - # Run the e2e model test cases - ${BUILD_DIR}/Debug/onnx_test_runner /onnxruntime_src/onnxruntime/test/testdata/ort_minimal_e2e_test_data -fi - -# Print binary size info -python3 /onnxruntime_src/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py \ - --arch "$(uname -m)" --os "$(uname -o)" --build_config "minimal-reduced" \ - ${BUILD_DIR}/Debug/libonnxruntime.so - -echo "Binary size info:" -cat ${BUILD_DIR}/Debug/binary_size_data.txt diff --git a/tools/nuget/generate_nuspec_for_custom_nuget.py b/tools/nuget/generate_nuspec_for_custom_nuget.py index baf46743cbf1b..4421bf3ba56b5 100644 --- a/tools/nuget/generate_nuspec_for_custom_nuget.py +++ b/tools/nuget/generate_nuspec_for_custom_nuget.py @@ -14,6 +14,8 @@ def generate_files(lines, args): platform_map = { "win-arm64": args.win_arm64, "win-x64": args.win_x64, + "osx-x64": args.osx_x64, + "osx-arm64": args.osx_arm64, } avoid_keywords = {"pdb"} @@ -112,6 +114,8 @@ def parse_arguments(): ) parser.add_argument("--win_arm64", required=True, help="Ort win-arm64 directory") parser.add_argument("--win_x64", required=True, help="Ort win-x64 directory") + parser.add_argument("--osx_arm64", required=True, help="Ort osx-arm64 directory") + parser.add_argument("--osx_x64", required=True, help="Ort osx-x64 directory") parser.add_argument("--package_version", required=True, help="Version of the package") parser.add_argument("--package_name", required=True, help="Name of the package") diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 1655529a5078d..c5a204b6cb958 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -543,18 +543,9 @@ def generate_files(line_list, args): + '" target="lib\\net5.0\\Microsoft.AI.MachineLearning.Interop.pdb" />' ) - if args.package_name == "Microsoft.ML.OnnxRuntime.Snpe" or args.package_name == "Microsoft.ML.OnnxRuntime.QNN": - files_list.append( - "" - ) - files_list.append( - "" - ) - if is_qnn_package: files_list.append("") files_list.append("") - files_list.append("") if args.target_architecture != "x64": files_list.append( "" @@ -574,12 +565,6 @@ def generate_files(line_list, args): files_list.append( "" ) - files_list.append( - "" - ) - files_list.append( - "" - ) is_ado_packaging_build = False # Process runtimes @@ -800,6 +785,16 @@ def generate_files(line_list, args): + '\\native" />' ) + if is_dml_package: + files_list.append( + "' + ) + # process all other library dependencies if is_cpu_package or is_cuda_gpu_package or is_dml_package or is_mklml_package: # Process dnnl dependency @@ -910,9 +905,20 @@ def generate_files(line_list, args): or is_qnn_package ): # Process props file - source_props = os.path.join( - args.sources_path, "csharp", "src", "Microsoft.ML.OnnxRuntime", "targets", "netstandard", "props.xml" - ) + if is_qnn_package: + source_props = os.path.join( + args.sources_path, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + "props_qnn.xml", + ) + else: + source_props = os.path.join( + args.sources_path, "csharp", "src", "Microsoft.ML.OnnxRuntime", "targets", "netstandard", "props.xml" + ) target_props = os.path.join( args.sources_path, "csharp",