Skip to content

Commit 1f07f7c

Browse files
authored
[flang][cuda] Add support for allocate with device source (#171743)
Add support for an ALLOCATE statement whose SOURCE= expression is a device variable.
1 parent 76ae530 commit 1f07f7c

File tree

10 files changed: 63 additions and 36 deletions.

flang-rt/lib/cuda/allocatable.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,26 +57,34 @@ int RTDEF(CUFAllocatableAllocate)(Descriptor &desc, int64_t *stream,
5757

5858
int RTDEF(CUFAllocatableAllocateSource)(Descriptor &alloc,
5959
const Descriptor &source, int64_t *stream, bool *pinned, bool hasStat,
60-
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
60+
const Descriptor *errMsg, const char *sourceFile, int sourceLine,
61+
bool sourceIsDevice) {
6162
int stat{RTNAME(CUFAllocatableAllocate)(
6263
alloc, stream, pinned, hasStat, errMsg, sourceFile, sourceLine)};
6364
if (stat == StatOk) {
6465
Terminator terminator{sourceFile, sourceLine};
65-
Fortran::runtime::DoFromSourceAssign(
66-
alloc, source, terminator, &MemmoveHostToDevice);
66+
Fortran::runtime::DoFromSourceAssign(alloc, source, terminator,
67+
sourceIsDevice ? &MemmoveDeviceToHost : &MemmoveHostToDevice);
6768
}
6869
return stat;
6970
}
7071

7172
int RTDEF(CUFAllocatableAllocateSourceSync)(Descriptor &alloc,
7273
const Descriptor &source, int64_t *stream, bool *pinned, bool hasStat,
73-
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
74-
int stat{RTNAME(CUFAllocatableAllocateSync)(
75-
alloc, stream, pinned, hasStat, errMsg, sourceFile, sourceLine)};
74+
const Descriptor *errMsg, const char *sourceFile, int sourceLine,
75+
bool sourceIsDevice) {
76+
int stat;
77+
if (sourceIsDevice) {
78+
stat = RTNAME(CUFAllocatableAllocate)(
79+
alloc, stream, pinned, hasStat, errMsg, sourceFile, sourceLine);
80+
} else {
81+
stat = RTNAME(CUFAllocatableAllocateSync)(
82+
alloc, stream, pinned, hasStat, errMsg, sourceFile, sourceLine);
83+
}
7684
if (stat == StatOk) {
7785
Terminator terminator{sourceFile, sourceLine};
78-
Fortran::runtime::DoFromSourceAssign(
79-
alloc, source, terminator, &MemmoveHostToDevice);
86+
Fortran::runtime::DoFromSourceAssign(alloc, source, terminator,
87+
sourceIsDevice ? &MemmoveDeviceToHost : &MemmoveHostToDevice);
8088
}
8189
return stat;
8290
}

flang-rt/lib/cuda/pointer.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,26 +56,28 @@ int RTDEF(CUFPointerAllocateSync)(Descriptor &desc, int64_t *stream,
5656

5757
int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer,
5858
const Descriptor &source, int64_t *stream, bool *pinned, bool hasStat,
59-
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
59+
const Descriptor *errMsg, const char *sourceFile, int sourceLine,
60+
bool sourceIsDevice) {
6061
int stat{RTNAME(CUFPointerAllocate)(
6162
pointer, stream, pinned, hasStat, errMsg, sourceFile, sourceLine)};
6263
if (stat == StatOk) {
6364
Terminator terminator{sourceFile, sourceLine};
64-
Fortran::runtime::DoFromSourceAssign(
65-
pointer, source, terminator, &MemmoveHostToDevice);
65+
Fortran::runtime::DoFromSourceAssign(pointer, source, terminator,
66+
sourceIsDevice ? &MemmoveDeviceToHost : &MemmoveHostToDevice);
6667
}
6768
return stat;
6869
}
6970

7071
int RTDEF(CUFPointerAllocateSourceSync)(Descriptor &pointer,
7172
const Descriptor &source, int64_t *stream, bool *pinned, bool hasStat,
72-
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
73+
const Descriptor *errMsg, const char *sourceFile, int sourceLine,
74+
bool sourceIsDevice) {
7375
int stat{RTNAME(CUFPointerAllocateSync)(
7476
pointer, stream, pinned, hasStat, errMsg, sourceFile, sourceLine)};
7577
if (stat == StatOk) {
7678
Terminator terminator{sourceFile, sourceLine};
77-
Fortran::runtime::DoFromSourceAssign(
78-
pointer, source, terminator, &MemmoveHostToDevice);
79+
Fortran::runtime::DoFromSourceAssign(pointer, source, terminator,
80+
sourceIsDevice ? &MemmoveDeviceToHost : &MemmoveHostToDevice);
7981
}
8082
return stat;
8183
}

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,9 @@ def cuf_AllocateOp : cuf_Op<"allocate", [AttrSizedOperandSegments,
100100
Optional<fir_ReferenceType>:$stream,
101101
Arg<Optional<AnyRefOrBoxType>, "", [MemWrite]>:$pinned,
102102
Arg<Optional<AnyRefOrBoxType>, "", [MemRead]>:$source,
103-
cuf_DataAttributeAttr:$data_attr, UnitAttr:$hasStat,
104-
UnitAttr:$hasDoubleDescriptor, UnitAttr:$pointer);
103+
OptionalAttr<cuf_DataAttributeAttr>:$data_attr, UnitAttr:$hasStat,
104+
UnitAttr:$hasDoubleDescriptor, UnitAttr:$pointer,
105+
UnitAttr:$device_source);
105106

106107
let results = (outs AnyIntegerType:$stat);
107108

flang/include/flang/Runtime/CUDA/allocatable.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@ int RTDECL(CUFAllocatableAllocateSync)(Descriptor &, int64_t *stream = nullptr,
3434
int RTDEF(CUFAllocatableAllocateSource)(Descriptor &alloc,
3535
const Descriptor &source, int64_t *stream = nullptr, bool *pinned = nullptr,
3636
bool hasStat = false, const Descriptor *errMsg = nullptr,
37-
const char *sourceFile = nullptr, int sourceLine = 0);
37+
const char *sourceFile = nullptr, int sourceLine = 0,
38+
bool sourceIsDevice = false);
3839

3940
/// Perform allocation of the descriptor with synchronization of it when
4041
/// necessary. Assign data from source.
4142
int RTDEF(CUFAllocatableAllocateSourceSync)(Descriptor &alloc,
4243
const Descriptor &source, int64_t *stream = nullptr, bool *pinned = nullptr,
4344
bool hasStat = false, const Descriptor *errMsg = nullptr,
44-
const char *sourceFile = nullptr, int sourceLine = 0);
45+
const char *sourceFile = nullptr, int sourceLine = 0,
46+
bool sourceIsDevice = false);
4547

4648
/// Perform deallocation of the descriptor with synchronization of it when
4749
/// necessary.

flang/include/flang/Runtime/CUDA/pointer.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@ int RTDECL(CUFPointerAllocateSync)(Descriptor &, int64_t *stream = nullptr,
3434
int RTDEF(CUFPointerAllocateSource)(Descriptor &pointer,
3535
const Descriptor &source, int64_t *stream = nullptr, bool *pinned = nullptr,
3636
bool hasStat = false, const Descriptor *errMsg = nullptr,
37-
const char *sourceFile = nullptr, int sourceLine = 0);
37+
const char *sourceFile = nullptr, int sourceLine = 0,
38+
bool sourceIsDevice = false);
3839

3940
/// Perform allocation of the descriptor with synchronization of it when
4041
/// necessary. Assign data from source.
4142
int RTDEF(CUFPointerAllocateSourceSync)(Descriptor &pointer,
4243
const Descriptor &source, int64_t *stream = nullptr, bool *pinned = nullptr,
4344
bool hasStat = false, const Descriptor *errMsg = nullptr,
44-
const char *sourceFile = nullptr, int sourceLine = 0);
45+
const char *sourceFile = nullptr, int sourceLine = 0,
46+
bool sourceIsDevice = false);
4547

4648
} // extern "C"
4749

flang/lib/Lower/Allocatable.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -629,9 +629,10 @@ class AllocateStmtHelper {
629629
unsigned allocatorIdx = Fortran::lower::getAllocatorIdx(alloc.getSymbol());
630630
fir::ExtendedValue exv = isSource ? sourceExv : moldExv;
631631

632+
bool sourceIsDevice = false;
632633
if (const Fortran::semantics::Symbol *sym{GetLastSymbol(sourceExpr)})
633634
if (Fortran::semantics::IsCUDADevice(*sym))
634-
TODO(loc, "CUDA Fortran: allocate with device source");
635+
sourceIsDevice = true;
635636

636637
// Generate a sequence of runtime calls.
637638
errorManager.genStatCheck(builder, loc);
@@ -651,7 +652,7 @@ class AllocateStmtHelper {
651652
genSetDeferredLengthParameters(alloc, box);
652653
genAllocateObjectBounds(alloc, box);
653654
mlir::Value stat;
654-
if (Fortran::semantics::HasCUDAAttr(alloc.getSymbol())) {
655+
if (Fortran::semantics::HasCUDAAttr(alloc.getSymbol()) || sourceIsDevice) {
655656
stat =
656657
genCudaAllocate(builder, loc, box, errorManager, alloc.getSymbol());
657658
} else {
@@ -798,13 +799,19 @@ class AllocateStmtHelper {
798799
// Keep return type the same as a standard AllocatableAllocate call.
799800
mlir::Type retTy = fir::runtime::getModel<int>()(builder.getContext());
800801

802+
bool isSourceDevice = false;
803+
if (const Fortran::semantics::Symbol *sym{GetLastSymbol(sourceExpr)})
804+
if (Fortran::semantics::IsCUDADevice(*sym))
805+
isSourceDevice = true;
806+
801807
bool doubleDescriptors = Fortran::lower::hasDoubleDescriptor(box.getAddr());
802808
return cuf::AllocateOp::create(
803809
builder, loc, retTy, box.getAddr(), errmsg, stream, pinned,
804810
source, cudaAttr,
805811
errorManager.hasStatSpec() ? builder.getUnitAttr() : nullptr,
806812
doubleDescriptors ? builder.getUnitAttr() : nullptr,
807-
box.isPointer() ? builder.getUnitAttr() : nullptr)
813+
box.isPointer() ? builder.getUnitAttr() : nullptr,
814+
isSourceDevice ? builder.getUnitAttr() : nullptr)
808815
.getResult();
809816
}
810817

flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ static mlir::LogicalResult convertOpToCall(OpTy op,
9999

100100
mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
101101
: builder.createBool(loc, false);
102-
103102
mlir::Value errmsg;
104103
if (op.getErrmsg()) {
105104
errmsg = op.getErrmsg();
@@ -116,12 +115,15 @@ static mlir::LogicalResult convertOpToCall(OpTy op,
116115
loc, fir::ReferenceType::get(
117116
mlir::IntegerType::get(op.getContext(), 1)));
118117
if (op.getSource()) {
118+
mlir::Value isDeviceSource = op.getDeviceSource()
119+
? builder.createBool(loc, true)
120+
: builder.createBool(loc, false);
119121
mlir::Value stream =
120122
op.getStream() ? op.getStream()
121123
: builder.createNullConstant(loc, fTy.getInput(2));
122124
args = fir::runtime::createArguments(
123125
builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned,
124-
hasStat, errmsg, sourceFile, sourceLine);
126+
hasStat, errmsg, sourceFile, sourceLine, isDeviceSource);
125127
} else {
126128
mlir::Value stream =
127129
op.getStream() ? op.getStream()

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,14 @@ func.func @_QPallocate_source() {
128128
%c1 = arith.constant 1 : index
129129
%c0 = arith.constant 0 : index
130130
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf32>>> {bindc_name = "a", uniq_name = "_QFallocate_sourceEa"}
131+
%devsource = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf32>>> {bindc_name = "a", uniq_name = "_QFallocate_sourceEa"}
131132
%4 = fir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFallocate_sourceEa"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
132133
%5 = cuf.alloc !fir.box<!fir.heap<!fir.array<?x?xf32>>> {bindc_name = "a_d", data_attr = #cuf.cuda<device>, uniq_name = "_QFallocate_sourceEa_d"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
133134
%7 = fir.declare %5 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFallocate_sourceEa_d"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
134135
%8 = fir.load %4 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
135136
%22 = cuf.allocate %7 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>> source(%8 : !fir.box<!fir.heap<!fir.array<?x?xf32>>>) {data_attr = #cuf.cuda<device>} -> i32
137+
%9 = fir.load %devsource : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
138+
%23 = cuf.allocate %7 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>> source(%9 : !fir.box<!fir.heap<!fir.array<?x?xf32>>>) {device_source} -> i32
136139
return
137140
}
138141

@@ -142,8 +145,8 @@ func.func @_QPallocate_source() {
142145
// CHECK: %[[SOURCE:.*]] = fir.load %[[DECL_HOST]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
143146
// CHECK: %[[DEV_CONV:.*]] = fir.convert %[[DECL_DEV]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
144147
// CHECK: %[[SOURCE_CONV:.*]] = fir.convert %[[SOURCE]] : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>) -> !fir.box<none>
145-
// CHECK: %{{.*}} = fir.call @_FortranACUFAllocatableAllocateSource(%[[DEV_CONV]], %[[SOURCE_CONV]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.ref<i64>, !fir.ref<i1>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
146-
148+
// CHECK: %{{.*}} = fir.call @_FortranACUFAllocatableAllocateSource(%[[DEV_CONV]], %[[SOURCE_CONV]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.ref<i64>, !fir.ref<i1>, i1, !fir.box<none>, !fir.ref<i8>, i32, i1) -> i32
149+
// CHECK: %{{.*}} = fir.call @_FortranACUFAllocatableAllocateSource(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %true{{.*}})
147150

148151
fir.global @_QMmod1Ea_d {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?x?xf32>>> {
149152
%c0 = arith.constant 0 : index

flang/test/Lower/CUDA/TODO/cuda-allocate-source-device.cuf

Lines changed: 0 additions & 9 deletions
This file was deleted.

flang/test/Lower/CUDA/cuda-allocatable.cuf

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,12 @@ end subroutine
261261
! CHECK: cuf.deallocate %{{.*}} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>, hasDoubleDescriptor} -> i32
262262
! CHECK: cuf.deallocate %{{.*}} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<managed>, hasDoubleDescriptor} -> i32
263263
! CHECK: cuf.deallocate %{{.*}} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<pinned>} -> i32
264+
265+
attributes(global) subroutine from_device_source()
266+
real, device, allocatable :: a(:)
267+
real, allocatable :: b(:)
268+
allocate(b, source=a)
269+
end subroutine
270+
271+
! CHECK-LABEL: func.func @_QPfrom_device_source()
272+
! CHECK: cuf.allocate{{.*}}device_source

0 commit comments

Comments
 (0)