Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions ucm/shared/trans/ascend/ascend_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,36 @@
* */
#include "ascend_buffer.h"
#include <acl/acl.h>
#include <limits>
#include <sys/mman.h>
#include "logger/logger.h"

namespace UC::Trans {

namespace {

constexpr std::uintptr_t HOST_REGISTER_PAGE_SIZE = 4096;

void FreeHostMemory(void* host)
{
auto ret = aclrtFreeHost(host);
if (ret != ACL_SUCCESS) { UC_ERROR("Failed to free host memory addr={} ret={}", host, ret); }
}

void* AlignUp(void* ptr, std::uintptr_t alignment)
{
const auto addr = reinterpret_cast<std::uintptr_t>(ptr);
return reinterpret_cast<void*>((addr + alignment - 1) / alignment * alignment);
}

void ReleaseHostPinnedMemory(void* registeredHost, void* allocatedHost)
{
Buffer::UnregisterHostBuffer(registeredHost);
FreeHostMemory(allocatedHost);
}

} // namespace

class HostHugePages : public std::enable_shared_from_this<HostHugePages> {
struct ConstructorKey {};
static constexpr auto HUGE_PAGE_SIZE = 2UL << 20;
Expand Down Expand Up @@ -124,6 +149,35 @@ std::shared_ptr<void> Trans::AscendBuffer::MakeHostBuffer(size_t size)
return nullptr;
}

std::shared_ptr<void> Trans::AscendBuffer::MakeHostPinnedBuffer(size_t size, void** pDevice)
{
if (pDevice) { *pDevice = nullptr; }

constexpr auto kMaxSize = std::numeric_limits<size_t>::max();
if (size > kMaxSize - (HOST_REGISTER_PAGE_SIZE - 1)) { return nullptr; }

void* allocatedHost = nullptr;
const auto allocationSize = size + HOST_REGISTER_PAGE_SIZE - 1;
auto ret = aclrtMallocHost(&allocatedHost, allocationSize);
if (ret != ACL_SUCCESS) { return nullptr; }

void* host = AlignUp(allocatedHost, HOST_REGISTER_PAGE_SIZE);

void* device = nullptr;
auto status = Buffer::RegisterHostBuffer(host, size, &device);
if (status.Failure()) {
UC_ERROR("Failed to register host-pinned memory addr={} size={} status={}", host, size,
status);
FreeHostMemory(allocatedHost);
return nullptr;
}

if (pDevice) { *pDevice = device; }
return std::shared_ptr<void>(host, [allocatedHost](void* registeredHost) {
ReleaseHostPinnedMemory(registeredHost, allocatedHost);
});
}

std::shared_ptr<void> Trans::AscendBuffer::MakeHostBuffer4DirectIo(size_t size)
{
try {
Expand Down
1 change: 1 addition & 0 deletions ucm/shared/trans/ascend/ascend_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class AscendBuffer : public ReservedBuffer {
public:
std::shared_ptr<void> MakeDeviceBuffer(size_t size) override;
std::shared_ptr<void> MakeHostBuffer(size_t size) override;
std::shared_ptr<void> MakeHostPinnedBuffer(size_t size, void** pDevice = nullptr) override;
std::shared_ptr<void> MakeHostBuffer4DirectIo(size_t size) override;
};

Expand Down
1 change: 1 addition & 0 deletions ucm/shared/trans/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Buffer {

virtual std::shared_ptr<void> MakeHostBuffer(size_t size) = 0;
virtual std::shared_ptr<void> MakeHostBuffer4DirectIo(size_t size) = 0;
virtual std::shared_ptr<void> MakeHostPinnedBuffer(size_t size, void** pDevice = nullptr) = 0;
Comment thread
yumingyue624 marked this conversation as resolved.
virtual Status MakeHostBuffers(size_t size, size_t number) = 0;
virtual std::shared_ptr<void> GetHostBuffer(size_t size) = 0;

Expand Down
6 changes: 6 additions & 0 deletions ucm/shared/trans/detail/reserved_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ class ReservedBuffer : public Buffer {
return this->MakeHostBuffer(size);
}

std::shared_ptr<void> MakeHostPinnedBuffer(size_t size, void** pDevice = nullptr) override
{
if (pDevice) { *pDevice = nullptr; }
return this->MakeHostBuffer(size);
}

Status MakeHostBuffers(size_t size, size_t number) override
{
auto totalSize = size * number;
Expand Down
121 changes: 107 additions & 14 deletions ucm/transport/kv/asu/test/case/buffer_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ TEST_F(BufferManagerTest, InitWithZeroSlotNum)
ASSERT_EQ(status.code, StatusCode::INVALID_ARGUMENT);
}

TEST_F(BufferManagerTest, InitHostWithUnalignedSlotCapacity)
{
BufferManager mgr;
auto status = mgr.Init("test_buffer", MemoryType::HOST, 1000, 10);
ASSERT_TRUE(status.ok()) << status.message;
}

TEST_F(BufferManagerTest, InitDeviceWithUnalignedSlotCapacity)
{
BufferManager mgr;
auto status = mgr.Init("test_buffer", MemoryType::ASCEND_DEVICE, 1000, 10);
ASSERT_TRUE(status.ok()) << status.message;
}

TEST_F(BufferManagerTest, DoubleInit)
{
BufferManager mgr;
Expand Down Expand Up @@ -123,12 +137,13 @@ TEST_F(BufferManagerTest, SingleAllocateAndFree)
ScatterGatherEntry sge;
status = mgr.Allocate(64, sge);
ASSERT_TRUE(status.ok()) << status.message;
ASSERT_NE(sge.addr, 0);
ASSERT_NE(sge.local_addr, 0);
ASSERT_EQ(sge.length, 64);
ASSERT_EQ(sge.tokenId, 0);
ASSERT_NE(sge.slot_index, UINT32_MAX);
ASSERT_EQ(sge.memory_type, MemoryType::HOST);

auto* ptr = reinterpret_cast<void*>(sge.addr);
auto* ptr = reinterpret_cast<void*>(sge.local_addr);
std::memset(ptr, 0xAB, 64);

status = mgr.Free(sge.slot_index);
Expand All @@ -147,12 +162,12 @@ TEST_F(BufferManagerTest, MultipleAllocatesAndFrees)
for (int i = 0; i < kCount; ++i) {
status = mgr.Allocate(128, sges[i]);
ASSERT_TRUE(status.ok()) << "Failed at i=" << i << ": " << status.message;
ASSERT_NE(sges[i].addr, 0);
std::memset(reinterpret_cast<void*>(sges[i].addr), i, 128);
ASSERT_NE(sges[i].local_addr, 0);
std::memset(reinterpret_cast<void*>(sges[i].local_addr), i, 128);
}

for (int i = 0; i < kCount; ++i) {
auto* data = reinterpret_cast<unsigned char*>(sges[i].addr);
auto* data = reinterpret_cast<unsigned char*>(sges[i].local_addr);
for (int j = 0; j < 128; ++j) { ASSERT_EQ(data[j], static_cast<unsigned char>(i)); }
}

Expand Down Expand Up @@ -191,10 +206,61 @@ TEST_F(BufferManagerTest, AllocateFullSlotSize)
status = mgr.Allocate(1024, sge);
ASSERT_TRUE(status.ok()) << status.message;
ASSERT_EQ(sge.length, 1024);
}

TEST_F(BufferManagerTest, AllocateFull4160ByteSlotCapacity)
{
BufferManager mgr;
auto status = mgr.Init("test_buffer", MemoryType::HOST, 4160, 10);
ASSERT_TRUE(status.ok());

std::memset(reinterpret_cast<void*>(sge.addr), 0xFF, 1024);
ScatterGatherEntry sge;
status = mgr.Allocate(4160, sge);
ASSERT_TRUE(status.ok()) << status.message;
ASSERT_EQ(sge.length, 4160);
}

mgr.Free(sge.slot_index);
TEST_F(BufferManagerTest, AllocateExceeds4160ByteSlotCapacity)
{
BufferManager mgr;
auto status = mgr.Init("test_buffer", MemoryType::HOST, 4160, 10);
ASSERT_TRUE(status.ok());

ScatterGatherEntry sge;
status = mgr.Allocate(4161, sge);
ASSERT_FALSE(status.ok());
ASSERT_EQ(status.code, StatusCode::INVALID_ARGUMENT);
}

TEST_F(BufferManagerTest, AllMemoryTypesUseAlignedSlotStride)
{
for (const auto type : {MemoryType::HOST, MemoryType::HOST_PINNED, MemoryType::ASCEND_DEVICE}) {
BufferManager mgr;
auto status = mgr.Init("test_buffer", type, 4160, 2);
ASSERT_TRUE(status.ok()) << status.message;

ScatterGatherEntry first;
ScatterGatherEntry second;
ASSERT_TRUE(mgr.Allocate(4160, first).ok());
ASSERT_TRUE(mgr.Allocate(4160, second).ok());
ASSERT_EQ(second.local_addr - first.local_addr, 4160);
ASSERT_EQ(second.device_addr - first.device_addr, 4160);
}
}

TEST_F(BufferManagerTest, FlagBufferCapacity71Uses128ByteStride)
{
BufferManager mgr;
auto status = mgr.Init("flag_buffer", MemoryType::HOST_PINNED, 71, 2);
ASSERT_TRUE(status.ok()) << status.message;

ScatterGatherEntry first;
ScatterGatherEntry second;
ASSERT_TRUE(mgr.Allocate(71, first).ok());
ASSERT_TRUE(mgr.Allocate(71, second).ok());
ASSERT_EQ(first.length, 71);
ASSERT_EQ(second.local_addr - first.local_addr, 128);
ASSERT_EQ(second.device_addr - first.device_addr, 128);
}

TEST_F(BufferManagerTest, ReuseAfterFree)
Expand All @@ -212,7 +278,7 @@ TEST_F(BufferManagerTest, ReuseAfterFree)
ScatterGatherEntry sge2;
status = mgr.Allocate(64, sge2);
ASSERT_TRUE(status.ok());
ASSERT_EQ(sge2.addr, sge1.addr);
ASSERT_EQ(sge2.local_addr, sge1.local_addr);
ASSERT_EQ(sge2.slot_index, sge1.slot_index);

mgr.Free(sge2.slot_index);
Expand All @@ -233,7 +299,7 @@ TEST_F(BufferManagerTest, ConcurrentAllocateAndFree)
auto s = mgr.Allocate(64, sge);
ASSERT_TRUE(s.ok()) << "Thread " << thread_id << " op " << i << ": " << s.message;

std::memset(reinterpret_cast<void*>(sge.addr), thread_id, 64);
std::memset(reinterpret_cast<void*>(sge.local_addr), thread_id, 64);

s = mgr.Free(sge.slot_index);
ASSERT_TRUE(s.ok()) << s.message;
Expand All @@ -260,10 +326,10 @@ TEST_F(BufferManagerTest, ConcurrentStressTest)
auto s = mgr.Allocate(128, sge);
ASSERT_TRUE(s.ok());

std::memset(reinterpret_cast<void*>(sge.addr), thread_id, 128);
std::memset(reinterpret_cast<void*>(sge.local_addr), thread_id, 128);

for (int j = 0; j < 128; ++j) {
ASSERT_EQ(reinterpret_cast<unsigned char*>(sge.addr)[j], thread_id);
ASSERT_EQ(reinterpret_cast<unsigned char*>(sge.local_addr)[j], thread_id);
}

s = mgr.Free(sge.slot_index);
Expand All @@ -286,7 +352,7 @@ TEST_F(BufferManagerTest, FreeZeroesMemory)
status = mgr.Allocate(64, sge1);
ASSERT_TRUE(status.ok());

auto* ptr = reinterpret_cast<uint8_t*>(sge1.addr);
auto* ptr = reinterpret_cast<uint8_t*>(sge1.local_addr);
std::memset(ptr, 0xAB, 1024);

status = mgr.Free(sge1.slot_index);
Expand All @@ -295,10 +361,10 @@ TEST_F(BufferManagerTest, FreeZeroesMemory)
ScatterGatherEntry sge2;
status = mgr.Allocate(64, sge2);
ASSERT_TRUE(status.ok());
ASSERT_EQ(sge2.addr, sge1.addr);
ASSERT_EQ(sge2.local_addr, sge1.local_addr);
ASSERT_EQ(sge2.slot_index, sge1.slot_index);

auto* ptr2 = reinterpret_cast<uint8_t*>(sge2.addr);
auto* ptr2 = reinterpret_cast<uint8_t*>(sge2.local_addr);
for (size_t i = 0; i < 1024; ++i) {
ASSERT_EQ(ptr2[i], 0) << "byte " << i << " not zeroed after free";
}
Expand Down Expand Up @@ -395,6 +461,33 @@ TEST_F(BufferManagerTest, InitWithProviderRegistersMemory)
ASSERT_NE(provider.lastAddr, 0);
ASSERT_EQ(provider.lastSize, 1024 * 10);
ASSERT_EQ(mgr.GetTokenId(), 42);

ScatterGatherEntry sge;
ASSERT_TRUE(mgr.Allocate(64, sge).ok());
ASSERT_EQ(sge.local_addr, sge.device_addr);
}

TEST_F(BufferManagerTest, HostPinnedRegistersDeviceAddress)
{
StubTransProvider provider;

BufferManager mgr;
auto status = mgr.Init("test_rdma_pinned", MemoryType::HOST_PINNED, 4096, 1, &provider);
ASSERT_TRUE(status.ok()) << status.message;
ASSERT_EQ(provider.registerCount, 1);
ASSERT_EQ(provider.lastMemType, TransProvider::MemType::MEM_DEVICE);

ScatterGatherEntry sge;
ASSERT_TRUE(mgr.Allocate(64, sge).ok());
ASSERT_NE(sge.local_addr, 0);
ASSERT_NE(sge.device_addr, 0);
ASSERT_NE(sge.local_addr, sge.device_addr);
ASSERT_EQ(sge.local_addr % 4096, 0);
ASSERT_EQ(provider.lastAddr, sge.device_addr);

// The CPU writes through addr while HCOMM and remote RDMA use device_addr.
std::memset(reinterpret_cast<void*>(sge.local_addr), 0x5A, sge.length);
ASSERT_EQ(*reinterpret_cast<unsigned char*>(sge.local_addr), 0x5A);
}

TEST_F(BufferManagerTest, InitWithProviderAllocateReturnsTokenId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,12 @@ struct TransportConfig {
bool preconnect{true};
bool bindCqPoller{true};

// Slot sizes are caller-visible capacities; BufferManager computes the
// aligned physical stride used for allocation and memory registration.
std::size_t sendBufferSlotSize{4160};
std::size_t sendBufferSlotNum{128};
std::size_t flagBufferSlotSize{128};
// Maximum memory required by a batch store/retrieve response flag buffer.
std::size_t flagBufferSlotSize{71};
Comment thread
yumingyue624 marked this conversation as resolved.
std::size_t flagBufferSlotNum{4096};
std::size_t asuBatchLoadIoNum{110};
std::size_t asuBatchStoreIoNum{110};
Expand Down
29 changes: 18 additions & 11 deletions ucm/transport/kv/asu/trans/src/asu_submit_flow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,31 @@ Status AsuTransportImpl::BuildSubBatchSendBuffers(
continue;
}

if (subBatchContext.flagBuffer.addr == 0 || subBatchContext.flagBuffer.length == 0) {
const auto subBatchStatus =
Status::Error(StatusCode::NOT_INITIALIZED, "sub-batch flag buffer is not ready");
if (subBatchContext.channel == nullptr ||
!IsTransportBufferReady(subBatchContext.sendSge) ||
!IsTransportBufferReady(subBatchContext.flagBuffer)) {
const auto subBatchStatus = Status::Error(StatusCode::NOT_INITIALIZED,
"sub-batch transport buffers are not ready");
UC_ERROR(
"Sub-batch flag buffer is not ready index={} cid={} flag_addr={} flag_length={}",
index, subBatchContext.cid, subBatchContext.flagBuffer.addr,
subBatchContext.flagBuffer.length);
"Sub-batch transport buffers are not ready index={} cid={} channel={} "
"send_local_addr={} send_device_addr={} send_length={} send_slot={} "
"flag_local_addr={} flag_device_addr={} flag_length={} flag_slot={}",
index, subBatchContext.cid, subBatchContext.channel != nullptr,
subBatchContext.sendSge.local_addr, subBatchContext.sendSge.device_addr,
subBatchContext.sendSge.length, subBatchContext.sendSge.slot_index,
subBatchContext.flagBuffer.local_addr, subBatchContext.flagBuffer.device_addr,
subBatchContext.flagBuffer.length, subBatchContext.flagBuffer.slot_index);
SetSubBatchSendFailed(subBatchContext, subBatchStatus);
if (status.ok()) { status = subBatchStatus; }
ReleaseSubBatchResources(subBatchContext);
continue;
}

ioBatches.push_back(
TransProvider::SendIoBatch{subBatchContext.channel->GetConnection(),
reinterpret_cast<void*>(subBatchContext.sendSge.addr),
reinterpret_cast<void*>(subBatchContext.flagBuffer.addr),
subBatchContext.sendSge.length});
ioBatches.push_back(TransProvider::SendIoBatch{
subBatchContext.channel->GetConnection(),
reinterpret_cast<void*>(subBatchContext.sendSge.device_addr),
reinterpret_cast<void*>(subBatchContext.flagBuffer.device_addr),
subBatchContext.sendSge.length});
subBatchIndexes.emplace_back(index);
}

Expand Down
Loading
Loading