Skip to content

Commit c6d79f5

Browse files
ianaylreble
authored and committed
Implement USM prefetch from device to host in SYCL runtime and UR (#19437)
Add the ability to control USM prefetch direction (host-to-device, device-to-host) in the enqueue_function extension:

```cpp
sycl::ext::oneapi::experimental {
enum class prefetch_type { device, host };
void prefetch(sycl::queue q, void* ptr, size_t numBytes, prefetch_type type = prefetch_type::device);
void prefetch(sycl::handler &h, void* ptr, size_t numBytes, prefetch_type type = prefetch_type::device);
}
```

**Note:**
- ~~There is a test failure regarding a new ABI symbol: In order to not break the ABI, I added a new handler function to represent prefetch from device-to-host. Despite the precommit failures, this should not be an ABI-breaking change and thus be okay to merge.~~ I modified the ABI symbols tests as a part of this PR.

---------

Co-authored-by: Pablo Reble <[email protected]>
1 parent 91860de commit c6d79f5

File tree

24 files changed

+273
-139
lines changed

24 files changed

+273
-139
lines changed

include/ur_api.h

Lines changed: 7 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

include/ur_print.hpp

Lines changed: 20 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

scripts/core/enqueue.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -915,13 +915,16 @@ etors:
915915
value: "$X_BIT(2)"
916916
--- #--------------------------------------------------------------------------
917917
type: enum
918-
desc: "Map flags"
919-
class: $xDevice
918+
desc: "USM migration flags, indicating the direction data is migrated in"
919+
class: $xEnqueue
920920
name: $x_usm_migration_flags_t
921921
etors:
922-
- name: DEFAULT
923-
desc: "Default migration TODO: Add more enums! "
922+
- name: HOST_TO_DEVICE
923+
desc: "Migrate data from host to device"
924924
value: "$X_BIT(0)"
925+
- name: DEVICE_TO_HOST
926+
desc: "Migrate data from device to host"
927+
value: "$X_BIT(1)"
925928
--- #--------------------------------------------------------------------------
926929
type: function
927930
desc: "Enqueue a command to map a region of the buffer object into the host address space and return a pointer to the mapped region"

scripts/core/exp-command-buffer.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1025,7 +1025,7 @@ params:
10251025
desc: "[in] size in bytes to be fetched."
10261026
- type: $x_usm_migration_flags_t
10271027
name: flags
1028-
desc: "[in] USM prefetch flags"
1028+
desc: "[in] USM migration flags"
10291029
- type: uint32_t
10301030
name: numSyncPointsInWaitList
10311031
desc: "[in] The number of sync points in the provided dependency list."

source/adapters/cuda/enqueue.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,14 +1516,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
15161516

15171517
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
15181518
ur_queue_handle_t hQueue, const void *pMem, size_t size,
1519-
ur_usm_migration_flags_t /*flags*/, uint32_t numEventsInWaitList,
1519+
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
15201520
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
15211521

1522+
ur_device_handle_t Device = hQueue->getDevice();
1523+
#if CUDA_VERSION >= 13000
1524+
CUmemLocation Location;
1525+
switch (flags) {
1526+
case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE:
1527+
Location.type = CU_MEM_LOCATION_TYPE_DEVICE;
1528+
Location.id = Device->get();
1529+
break;
1530+
case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST:
1531+
Location.type = CU_MEM_LOCATION_TYPE_HOST;
1532+
break;
1533+
#else
1534+
int dstDevice;
1535+
switch (flags) {
1536+
case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE:
1537+
dstDevice = Device->get();
1538+
break;
1539+
case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST:
1540+
dstDevice = CU_DEVICE_CPU;
1541+
break;
1542+
#endif
1543+
default:
1544+
setErrorMessage("Invalid USM migration flag",
1545+
UR_RESULT_ERROR_INVALID_ENUMERATION);
1546+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
1547+
}
1548+
15221549
size_t PointerRangeSize = 0;
15231550
UR_CHECK_ERROR(cuPointerGetAttribute(
15241551
&PointerRangeSize, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)pMem));
15251552
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
1526-
ur_device_handle_t Device = hQueue->getDevice();
15271553

15281554
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
15291555
try {
@@ -1564,15 +1590,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
15641590
}
15651591

15661592
#if CUDA_VERSION >= 13000
1567-
CUmemLocation Location;
1568-
Location.id = Device->get();
1569-
Location.type = CU_MEM_LOCATION_TYPE_DEVICE;
15701593
unsigned int Flags = 0U;
15711594
UR_CHECK_ERROR(
15721595
cuMemPrefetchAsync((CUdeviceptr)pMem, size, Location, Flags, CuStream));
15731596
#else
15741597
UR_CHECK_ERROR(
1575-
cuMemPrefetchAsync((CUdeviceptr)pMem, size, Device->get(), CuStream));
1598+
cuMemPrefetchAsync((CUdeviceptr)pMem, size, dstDevice, CuStream));
15761599
#endif
15771600
} catch (ur_result_t Err) {
15781601
return Err;

source/adapters/hip/enqueue.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,11 +1324,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
13241324

13251325
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13261326
ur_queue_handle_t hQueue, const void *pMem, size_t size,
1327-
ur_usm_migration_flags_t /*flags*/, uint32_t numEventsInWaitList,
1327+
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
13281328
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
13291329

1330-
void *HIPDevicePtr = const_cast<void *>(pMem);
13311330
ur_device_handle_t Device = hQueue->getDevice();
1331+
hipDevice_t TargetDevice;
1332+
switch (flags) {
1333+
case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE:
1334+
TargetDevice = Device->get();
1335+
break;
1336+
case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST:
1337+
TargetDevice = hipCpuDeviceId;
1338+
break;
1339+
default:
1340+
setErrorMessage("Invalid USM migration flag",
1341+
UR_RESULT_ERROR_INVALID_ENUMERATION);
1342+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
1343+
}
1344+
void *HIPDevicePtr = const_cast<void *>(pMem);
13321345

13331346
// HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5,
13341347
// so we can't perform this check for such cases.
@@ -1385,8 +1398,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13851398
return UR_RESULT_SUCCESS;
13861399
}
13871400

1388-
UR_CHECK_ERROR(
1389-
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
1401+
UR_CHECK_ERROR(hipMemPrefetchAsync(pMem, size, TargetDevice, HIPStream));
13901402
releaseEvent();
13911403
} catch (ur_result_t Err) {
13921404
return Err;

source/adapters/level_zero/adapter.cpp

Lines changed: 57 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -506,72 +506,71 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
506506
bool forceLoadedAdapter = ur_getenv("UR_ADAPTERS_FORCE_LOAD").has_value();
507507
if (!forceLoadedAdapter) {
508508
#ifdef UR_ADAPTER_LEVEL_ZERO_V2
509-
auto [useV2, reason] = shouldUseV2Adapter();
510-
if (!useV2) {
511-
UR_LOG(INFO, "Skipping L0 V2 adapter: {}", reason);
512-
return;
513-
}
509+
auto [useV2, reason] = shouldUseV2Adapter();
510+
if (!useV2) {
511+
UR_LOG(INFO, "Skipping L0 V2 adapter: {}", reason);
512+
return;
513+
}
514514
#else
515-
auto [useV1, reason] = shouldUseV1Adapter();
516-
if (!useV1) {
517-
UR_LOG(INFO, "Skipping L0 V1 adapter: {}", reason);
518-
return;
519-
}
515+
auto [useV1, reason] = shouldUseV1Adapter();
516+
if (!useV1) {
517+
UR_LOG(INFO, "Skipping L0 V1 adapter: {}", reason);
518+
return;
519+
}
520520
#endif
521521
}
522522

523-
// Check if the user has enabled the default L0 SysMan initialization.
524-
const int UrSysmanZesinitEnable = [&UserForcedSysManInit] {
525-
const char *UrRet = std::getenv("UR_L0_ENABLE_ZESINIT_DEFAULT");
526-
if (!UrRet)
527-
return 0;
528-
UserForcedSysManInit &= 2;
529-
return std::atoi(UrRet);
530-
}();
531-
532-
bool ZesInitNeeded = UrSysmanZesinitEnable && !UrSysManEnvInitEnabled;
533-
// Unless the user has forced the SysMan init, we will check the device
534-
// version to see if the zesInit is needed.
535-
if (UserForcedSysManInit == 0 && checkDeviceIntelGPUIpVersionOrNewer(
536-
0x05004000) == UR_RESULT_SUCCESS) {
537-
if (UrSysManEnvInitEnabled) {
538-
setEnvVar("ZES_ENABLE_SYSMAN", "0");
539-
}
540-
ZesInitNeeded = true;
541-
}
542-
if (ZesInitNeeded) {
523+
// Check if the user has enabled the default L0 SysMan initialization.
524+
const int UrSysmanZesinitEnable = [&UserForcedSysManInit] {
525+
const char *UrRet = std::getenv("UR_L0_ENABLE_ZESINIT_DEFAULT");
526+
if (!UrRet)
527+
return 0;
528+
UserForcedSysManInit &= 2;
529+
return std::atoi(UrRet);
530+
}();
531+
532+
bool ZesInitNeeded = UrSysmanZesinitEnable && !UrSysManEnvInitEnabled;
533+
// Unless the user has forced the SysMan init, we will check the device
534+
// version to see if the zesInit is needed.
535+
if (UserForcedSysManInit == 0 &&
536+
checkDeviceIntelGPUIpVersionOrNewer(0x05004000) == UR_RESULT_SUCCESS) {
537+
if (UrSysManEnvInitEnabled) {
538+
setEnvVar("ZES_ENABLE_SYSMAN", "0");
539+
}
540+
ZesInitNeeded = true;
541+
}
542+
if (ZesInitNeeded) {
543543
#ifdef UR_STATIC_LEVEL_ZERO
544-
getDeviceByUUIdFunctionPtr = zesDriverGetDeviceByUuidExp;
545-
getSysManDriversFunctionPtr = zesDriverGet;
546-
sysManInitFunctionPtr = zesInit;
544+
getDeviceByUUIdFunctionPtr = zesDriverGetDeviceByUuidExp;
545+
getSysManDriversFunctionPtr = zesDriverGet;
546+
sysManInitFunctionPtr = zesInit;
547547
#else
548-
getDeviceByUUIdFunctionPtr = (zes_pfnDriverGetDeviceByUuidExp_t)
549-
ur_loader::LibLoader::getFunctionPtr(processHandle,
550-
"zesDriverGetDeviceByUuidExp");
551-
getSysManDriversFunctionPtr =
552-
(zes_pfnDriverGet_t)ur_loader::LibLoader::getFunctionPtr(
553-
processHandle, "zesDriverGet");
554-
sysManInitFunctionPtr =
555-
(zes_pfnInit_t)ur_loader::LibLoader::getFunctionPtr(processHandle,
556-
"zesInit");
548+
getDeviceByUUIdFunctionPtr =
549+
(zes_pfnDriverGetDeviceByUuidExp_t)ur_loader::LibLoader::getFunctionPtr(
550+
processHandle, "zesDriverGetDeviceByUuidExp");
551+
getSysManDriversFunctionPtr =
552+
(zes_pfnDriverGet_t)ur_loader::LibLoader::getFunctionPtr(
553+
processHandle, "zesDriverGet");
554+
sysManInitFunctionPtr = (zes_pfnInit_t)ur_loader::LibLoader::getFunctionPtr(
555+
processHandle, "zesInit");
557556
#endif
558-
}
559-
if (getDeviceByUUIdFunctionPtr && getSysManDriversFunctionPtr &&
560-
sysManInitFunctionPtr) {
561-
ze_init_flags_t L0ZesInitFlags = 0;
562-
UR_LOG(DEBUG, "\nzesInit with flags value of {}\n",
563-
static_cast<int>(L0ZesInitFlags));
564-
ZesResult = ZE_CALL_NOCHECK(sysManInitFunctionPtr, (L0ZesInitFlags));
565-
} else {
566-
ZesResult = ZE_RESULT_ERROR_UNINITIALIZED;
567-
}
557+
}
558+
if (getDeviceByUUIdFunctionPtr && getSysManDriversFunctionPtr &&
559+
sysManInitFunctionPtr) {
560+
ze_init_flags_t L0ZesInitFlags = 0;
561+
UR_LOG(DEBUG, "\nzesInit with flags value of {}\n",
562+
static_cast<int>(L0ZesInitFlags));
563+
ZesResult = ZE_CALL_NOCHECK(sysManInitFunctionPtr, (L0ZesInitFlags));
564+
} else {
565+
ZesResult = ZE_RESULT_ERROR_UNINITIALIZED;
566+
}
568567

569-
ur_result_t err = initPlatforms(this, platforms, ZesResult);
570-
if (err == UR_RESULT_SUCCESS) {
571-
Platforms = std::move(platforms);
572-
} else {
573-
throw err;
574-
}
568+
ur_result_t err = initPlatforms(this, platforms, ZesResult);
569+
if (err == UR_RESULT_SUCCESS) {
570+
Platforms = std::move(platforms);
571+
} else {
572+
throw err;
573+
}
575574
}
576575

577576
void globalAdapterOnDemandCleanup() {

source/adapters/level_zero/command_buffer.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,7 +1313,7 @@ ur_result_t urCommandBufferAppendMemBufferReadRectExp(
13131313

13141314
ur_result_t urCommandBufferAppendUSMPrefetchExp(
13151315
ur_exp_command_buffer_handle_t CommandBuffer, const void *Mem, size_t Size,
1316-
ur_usm_migration_flags_t /*Flags*/, uint32_t NumSyncPointsInWaitList,
1316+
ur_usm_migration_flags_t Flags, uint32_t NumSyncPointsInWaitList,
13171317
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
13181318
uint32_t /*NumEventsInWaitList*/,
13191319
const ur_event_handle_t * /*EventWaitList*/,
@@ -1327,6 +1327,17 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp(
13271327
UR_COMMAND_USM_PREFETCH, CommandBuffer,
13281328
CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList,
13291329
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));
1330+
switch (Flags) {
1331+
case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE:
1332+
break;
1333+
case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST:
1334+
UR_LOG(WARN, "commandBufferAppendUSMPrefetch: L0 does not support prefetch "
1335+
"to host yet");
1336+
break;
1337+
default:
1338+
UR_LOG(ERR, "commandBufferAppendUSMPrefetch: invalid USM migration flag");
1339+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
1340+
}
13301341

13311342
if (!ZeEventList.empty()) {
13321343
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
@@ -1335,9 +1346,11 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp(
13351346
}
13361347

13371348
// Add the prefetch command to the command-buffer.
1338-
// Note that L0 does not handle migration flags.
1339-
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
1340-
(CommandBuffer->ZeComputeCommandList, Mem, Size));
1349+
// TODO Support migration flags after L0 backend support is added.
1350+
if (Flags == UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) {
1351+
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
1352+
(CommandBuffer->ZeComputeCommandList, Mem, Size));
1353+
}
13411354

13421355
if (!CommandBuffer->IsInOrderCmdList) {
13431356
// Level Zero does not have a completion "event" with the prefetch API,

0 commit comments

Comments
 (0)