Commit cca2d09

io: Correctly align async closure contexts
This fixes package fetching on Windows. Previously, `Async/GroupClosure` allocations were only aligned for the closure struct type, which resulted in panics when `context_alignment` (or `result_alignment` for that matter) had a greater alignment.
1 parent 5f13922 commit cca2d09
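
In other words, the closure header and its context share a single allocation that is only guaranteed to be aligned for the header, so a more strictly aligned context needs worst-case padding reserved up front, and its real offset must be computed from the allocation's runtime address rather than from `@sizeOf` alone. A minimal sketch of the corrected pattern, with a hypothetical `Header` struct standing in for `AsyncClosure`/`GroupClosure` (an illustration, not the actual stdlib code):

```zig
const std = @import("std");

const Header = struct { alloc_len: usize };

/// Sketch: store `context` after `Header` in one allocation that is only
/// guaranteed to be aligned to @alignOf(Header).
fn allocWithContext(
    gpa: std.mem.Allocator,
    context: []const u8,
    context_alignment: std.mem.Alignment,
) !*Header {
    // Worst case, the context needs (context_align - @alignOf(Header)) extra
    // bytes of padding; `-|` saturates to zero when the context is no more
    // aligned than the header.
    const max_misalignment = context_alignment.toByteUnits() -| @alignOf(Header);
    const alloc_len = context_alignment.forward(@sizeOf(Header) + max_misalignment) + context.len;
    const bytes = try gpa.alignedAlloc(u8, .of(Header), alloc_len);

    // The actual offset depends on where the allocation landed, so derive it
    // from the runtime address; aligning @sizeOf(Header) alone (the old code)
    // under-reserves whenever the base address is misaligned for the context.
    const base = @intFromPtr(bytes.ptr);
    const context_offset = context_alignment.forward(base + @sizeOf(Header)) - base;
    @memcpy(bytes[context_offset..][0..context.len], context);

    const header: *Header = @ptrCast(@alignCast(bytes.ptr));
    header.* = .{ .alloc_len = alloc_len }; // stored so freeing can recover the full length
    return header;
}
```

Storing the total length in the header is the same trick the diff below adds as the `alloc_len` field: once the offsets depend on the runtime address, the allocation's length can no longer be recomputed at free time.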

2 files changed (+165 −83 lines)

lib/std/Io/Threaded.zig

Lines changed: 106 additions & 83 deletions
```diff
@@ -389,6 +389,7 @@ const AsyncClosure = struct {
     select_condition: ?*ResetEvent,
     context_alignment: std.mem.Alignment,
     result_offset: usize,
+    alloc_len: usize,
 
     const done_reset_event: *ResetEvent = @ptrFromInt(@alignOf(ResetEvent));
 
@@ -425,18 +426,59 @@ const AsyncClosure = struct {
 
     fn contextPointer(ac: *AsyncClosure) [*]u8 {
         const base: [*]u8 = @ptrCast(ac);
-        return base + ac.context_alignment.forward(@sizeOf(AsyncClosure));
+        const context_offset = ac.context_alignment.forward(@intFromPtr(ac) + @sizeOf(AsyncClosure)) - @intFromPtr(ac);
+        return base + context_offset;
+    }
+
+    fn init(
+        gpa: Allocator,
+        mode: enum { async, concurrent },
+        result_len: usize,
+        result_alignment: std.mem.Alignment,
+        context: []const u8,
+        context_alignment: std.mem.Alignment,
+        func: *const fn (context: *const anyopaque, result: *anyopaque) void,
+    ) Allocator.Error!*AsyncClosure {
+        const max_context_misalignment = context_alignment.toByteUnits() -| @alignOf(AsyncClosure);
+        const worst_case_context_offset = context_alignment.forward(@sizeOf(AsyncClosure) + max_context_misalignment);
+        const worst_case_result_offset = result_alignment.forward(worst_case_context_offset + context.len);
+        const alloc_len = worst_case_result_offset + result_len;
+
+        const ac: *AsyncClosure = @ptrCast(@alignCast(try gpa.alignedAlloc(u8, .of(AsyncClosure), alloc_len)));
+        errdefer comptime unreachable;
+
+        const actual_context_addr = context_alignment.forward(@intFromPtr(ac) + @sizeOf(AsyncClosure));
+        const actual_result_addr = result_alignment.forward(actual_context_addr + context.len);
+        const actual_result_offset = actual_result_addr - @intFromPtr(ac);
+        ac.* = .{
+            .closure = .{
+                .cancel_tid = .none,
+                .start = start,
+                .is_concurrent = switch (mode) {
+                    .async => false,
+                    .concurrent => true,
+                },
+            },
+            .func = func,
+            .context_alignment = context_alignment,
+            .result_offset = actual_result_offset,
+            .alloc_len = alloc_len,
+            .reset_event = .unset,
+            .select_condition = null,
+        };
+        @memcpy(ac.contextPointer()[0..context.len], context);
+        return ac;
     }
 
-    fn waitAndFree(ac: *AsyncClosure, gpa: Allocator, result: []u8) void {
+    fn waitAndDeinit(ac: *AsyncClosure, gpa: Allocator, result: []u8) void {
         ac.reset_event.waitUncancelable();
         @memcpy(result, ac.resultPointer()[0..result.len]);
-        free(ac, gpa, result.len);
+        ac.deinit(gpa);
     }
 
-    fn free(ac: *AsyncClosure, gpa: Allocator, result_len: usize) void {
+    fn deinit(ac: *AsyncClosure, gpa: Allocator) void {
         const base: [*]align(@alignOf(AsyncClosure)) u8 = @ptrCast(ac);
-        gpa.free(base[0 .. ac.result_offset + result_len]);
+        gpa.free(base[0..ac.alloc_len]);
     }
 };
 
@@ -452,44 +494,28 @@ fn async(
         start(context.ptr, result.ptr);
         return null;
     }
+
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const cpu_count = t.cpu_count catch {
         return concurrent(userdata, result.len, result_alignment, context, context_alignment, start) catch {
             start(context.ptr, result.ptr);
             return null;
         };
     };
+
     const gpa = t.allocator;
-    const context_offset = context_alignment.forward(@sizeOf(AsyncClosure));
-    const result_offset = result_alignment.forward(context_offset + context.len);
-    const n = result_offset + result.len;
-    const ac: *AsyncClosure = @ptrCast(@alignCast(gpa.alignedAlloc(u8, .of(AsyncClosure), n) catch {
+    const ac = AsyncClosure.init(gpa, .async, result.len, result_alignment, context, context_alignment, start) catch {
         start(context.ptr, result.ptr);
         return null;
-    }));
-
-    ac.* = .{
-        .closure = .{
-            .cancel_tid = .none,
-            .start = AsyncClosure.start,
-            .is_concurrent = false,
-        },
-        .func = start,
-        .context_alignment = context_alignment,
-        .result_offset = result_offset,
-        .reset_event = .unset,
-        .select_condition = null,
     };
 
-    @memcpy(ac.contextPointer()[0..context.len], context);
-
     t.mutex.lock();
 
     const thread_capacity = cpu_count - 1 + t.concurrent_count;
 
     t.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
         t.mutex.unlock();
-        ac.free(gpa, result.len);
+        ac.deinit(gpa);
         start(context.ptr, result.ptr);
         return null;
     };
@@ -501,7 +527,7 @@ fn async(
     if (t.threads.items.len == 0) {
         assert(t.run_queue.popFirst() == &ac.closure.node);
         t.mutex.unlock();
-        ac.free(gpa, result.len);
+        ac.deinit(gpa);
         start(context.ptr, result.ptr);
         return null;
     }
@@ -530,27 +556,11 @@ fn concurrent(
 
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const cpu_count = t.cpu_count catch 1;
+
     const gpa = t.allocator;
-    const context_offset = context_alignment.forward(@sizeOf(AsyncClosure));
-    const result_offset = result_alignment.forward(context_offset + context.len);
-    const n = result_offset + result_len;
-    const ac_bytes = gpa.alignedAlloc(u8, .of(AsyncClosure), n) catch
+    const ac = AsyncClosure.init(gpa, .concurrent, result_len, result_alignment, context, context_alignment, start) catch {
         return error.ConcurrencyUnavailable;
-    const ac: *AsyncClosure = @ptrCast(@alignCast(ac_bytes));
-
-    ac.* = .{
-        .closure = .{
-            .cancel_tid = .none,
-            .start = AsyncClosure.start,
-            .is_concurrent = true,
-        },
-        .func = start,
-        .context_alignment = context_alignment,
-        .result_offset = result_offset,
-        .reset_event = .unset,
-        .select_condition = null,
     };
-    @memcpy(ac.contextPointer()[0..context.len], context);
 
     t.mutex.lock();
 
@@ -559,7 +569,7 @@ fn concurrent(
 
     t.threads.ensureTotalCapacity(gpa, thread_capacity) catch {
         t.mutex.unlock();
-        ac.free(gpa, result_len);
+        ac.deinit(gpa);
         return error.ConcurrencyUnavailable;
     };
 
@@ -569,7 +579,7 @@ fn concurrent(
     const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
         assert(t.run_queue.popFirst() == &ac.closure.node);
         t.mutex.unlock();
-        ac.free(gpa, result_len);
+        ac.deinit(gpa);
        return error.ConcurrencyUnavailable;
     };
     t.threads.appendAssumeCapacity(thread);
@@ -588,7 +598,7 @@ const GroupClosure = struct {
     node: std.SinglyLinkedList.Node,
     func: *const fn (*Io.Group, context: *anyopaque) void,
     context_alignment: std.mem.Alignment,
-    context_len: usize,
+    alloc_len: usize,
 
     fn start(closure: *Closure) void {
         const gc: *GroupClosure = @alignCast(@fieldParentPtr("closure", closure));
@@ -616,22 +626,48 @@ const GroupClosure = struct {
         if (prev_state == (sync_one_pending | sync_is_waiting)) reset_event.set();
     }
 
-    fn free(gc: *GroupClosure, gpa: Allocator) void {
-        const base: [*]align(@alignOf(GroupClosure)) u8 = @ptrCast(gc);
-        gpa.free(base[0..contextEnd(gc.context_alignment, gc.context_len)]);
-    }
-
-    fn contextOffset(context_alignment: std.mem.Alignment) usize {
-        return context_alignment.forward(@sizeOf(GroupClosure));
-    }
-
-    fn contextEnd(context_alignment: std.mem.Alignment, context_len: usize) usize {
-        return contextOffset(context_alignment) + context_len;
-    }
-
     fn contextPointer(gc: *GroupClosure) [*]u8 {
         const base: [*]u8 = @ptrCast(gc);
-        return base + contextOffset(gc.context_alignment);
+        const context_offset = gc.context_alignment.forward(@intFromPtr(gc) + @sizeOf(GroupClosure)) - @intFromPtr(gc);
+        return base + context_offset;
+    }
+
+    /// Does not initialize the `node` field.
+    fn init(
+        gpa: Allocator,
+        t: *Threaded,
+        group: *Io.Group,
+        context: []const u8,
+        context_alignment: std.mem.Alignment,
+        func: *const fn (*Io.Group, context: *const anyopaque) void,
+    ) Allocator.Error!*GroupClosure {
+        const max_context_misalignment = context_alignment.toByteUnits() -| @alignOf(GroupClosure);
+        const worst_case_context_offset = context_alignment.forward(@sizeOf(GroupClosure) + max_context_misalignment);
+        const alloc_len = worst_case_context_offset + context.len;
+
+        const gc: *GroupClosure = @ptrCast(@alignCast(try gpa.alignedAlloc(u8, .of(GroupClosure), alloc_len)));
+        errdefer comptime unreachable;
+
+        gc.* = .{
+            .closure = .{
+                .cancel_tid = .none,
+                .start = start,
+                .is_concurrent = false,
+            },
+            .t = t,
+            .group = group,
+            .node = undefined,
+            .func = func,
+            .context_alignment = context_alignment,
+            .alloc_len = alloc_len,
+        };
+        @memcpy(gc.contextPointer()[0..context.len], context);
+        return gc;
+    }
+
+    fn deinit(gc: *GroupClosure, gpa: Allocator) void {
+        const base: [*]align(@alignOf(GroupClosure)) u8 = @ptrCast(gc);
+        gpa.free(base[0..gc.alloc_len]);
     }
 
     const sync_is_waiting: usize = 1 << 0;
@@ -646,27 +682,14 @@ fn groupAsync(
     start: *const fn (*Io.Group, context: *const anyopaque) void,
 ) void {
     if (builtin.single_threaded) return start(group, context.ptr);
+
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const cpu_count = t.cpu_count catch 1;
+
     const gpa = t.allocator;
-    const n = GroupClosure.contextEnd(context_alignment, context.len);
-    const gc: *GroupClosure = @ptrCast(@alignCast(gpa.alignedAlloc(u8, .of(GroupClosure), n) catch {
+    const gc = GroupClosure.init(gpa, t, group, context, context_alignment, start) catch {
         return start(group, context.ptr);
-    }));
-    gc.* = .{
-        .closure = .{
-            .cancel_tid = .none,
-            .start = GroupClosure.start,
-            .is_concurrent = false,
-        },
-        .t = t,
-        .group = group,
-        .node = undefined,
-        .func = start,
-        .context_alignment = context_alignment,
-        .context_len = context.len,
     };
-    @memcpy(gc.contextPointer()[0..context.len], context);
 
     t.mutex.lock();
 
@@ -678,7 +701,7 @@ fn groupAsync(
 
     t.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
         t.mutex.unlock();
-        gc.free(gpa);
+        gc.deinit(gpa);
         return start(group, context.ptr);
     };
 
@@ -688,7 +711,7 @@ fn groupAsync(
     const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
         assert(t.run_queue.popFirst() == &gc.closure.node);
         t.mutex.unlock();
-        gc.free(gpa);
+        gc.deinit(gpa);
         return start(group, context.ptr);
     };
     t.threads.appendAssumeCapacity(thread);
@@ -730,7 +753,7 @@ fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
     while (true) {
         const gc: *GroupClosure = @fieldParentPtr("node", node);
         const node_next = node.next;
-        gc.free(gpa);
+        gc.deinit(gpa);
         node = node_next orelse break;
     }
 }
@@ -761,7 +784,7 @@ fn groupCancel(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void
     while (true) {
         const gc: *GroupClosure = @fieldParentPtr("node", node);
         const node_next = node.next;
-        gc.free(gpa);
+        gc.deinit(gpa);
         node = node_next orelse break;
     }
 }
@@ -776,7 +799,7 @@ fn await(
     _ = result_alignment;
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
-    closure.waitAndFree(t.allocator, result);
+    closure.waitAndDeinit(t.allocator, result);
 }
 
 fn cancel(
@@ -789,7 +812,7 @@ fn cancel(
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const ac: *AsyncClosure = @ptrCast(@alignCast(any_future));
     ac.closure.requestCancel();
-    ac.waitAndFree(t.allocator, result);
+    ac.waitAndDeinit(t.allocator, result);
 }
 
 fn cancelRequested(userdata: ?*anyopaque) bool {
```
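
To see why the worst-case reservation matters, here is a worked example of the arithmetic in `AsyncClosure.init` above. The numbers are hypothetical (the real struct size differs), chosen only to make the overrun concrete:

```zig
const std = @import("std");

test "worst-case context padding, illustrative numbers" {
    // Assume an 8-byte-aligned, 48-byte closure header and a context that
    // demands 64-byte alignment.
    const header_align: usize = 8;
    const header_size: usize = 48;
    const context_alignment = std.mem.Alignment.fromByteUnits(64);

    // The old code reserved forward(header_size) = 64 bytes ahead of the
    // context, independent of where the allocation landed.
    try std.testing.expectEqual(@as(usize, 64), context_alignment.forward(header_size));

    // But a perfectly legal 8-byte-aligned base address that is 24 (mod 64)
    // pushes the context to offset 104, overrunning the 64 reserved bytes.
    const base: usize = 1024 + 24;
    const offset = context_alignment.forward(base + header_size) - base;
    try std.testing.expectEqual(@as(usize, 104), offset);

    // The fix reserves for the worst case: header size plus the maximal
    // misalignment (64 - 8 = 56), rounded up to the context alignment.
    const max_misalignment = context_alignment.toByteUnits() -| header_align;
    const worst_case_offset = context_alignment.forward(header_size + max_misalignment);
    try std.testing.expectEqual(@as(usize, 128), worst_case_offset);
}
```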

lib/std/Io/Threaded/test.zig

Lines changed: 59 additions & 0 deletions
```diff
@@ -56,3 +56,62 @@ test "concurrent vs concurrent prevents deadlock via oversubscription" {
     getter.await(io);
     putter.await(io);
 }
+
+const ByteArray256 = struct { x: [32]u8 align(32) };
+const ByteArray512 = struct { x: [64]u8 align(64) };
+
+fn concatByteArrays(a: ByteArray256, b: ByteArray256) ByteArray512 {
+    return .{ .x = a.x ++ b.x };
+}
+
+test "async/concurrent context and result alignment" {
+    var buffer: [2048]u8 align(@alignOf(ByteArray512)) = undefined;
+    var fba: std.heap.FixedBufferAllocator = .init(&buffer);
+
+    var threaded: std.Io.Threaded = .init(fba.allocator());
+    defer threaded.deinit();
+    const io = threaded.io();
+
+    const a: ByteArray256 = .{ .x = @splat(2) };
+    const b: ByteArray256 = .{ .x = @splat(3) };
+    const expected: ByteArray512 = .{ .x = @as([32]u8, @splat(2)) ++ @as([32]u8, @splat(3)) };
+
+    {
+        var future = io.async(concatByteArrays, .{ a, b });
+        const result = future.await(io);
+        try std.testing.expectEqualSlices(u8, &expected.x, &result.x);
+    }
+    {
+        var future = io.concurrent(concatByteArrays, .{ a, b }) catch |err| switch (err) {
+            error.ConcurrencyUnavailable => {
+                try testing.expect(builtin.single_threaded);
+                return;
+            },
+        };
+        const result = future.await(io);
+        try std.testing.expectEqualSlices(u8, &expected.x, &result.x);
+    }
+}
+
+fn concatByteArraysResultPtr(a: ByteArray256, b: ByteArray256, result: *ByteArray512) void {
+    result.* = .{ .x = a.x ++ b.x };
+}
+
+test "Group.async context alignment" {
+    var buffer: [2048]u8 align(@alignOf(ByteArray512)) = undefined;
+    var fba: std.heap.FixedBufferAllocator = .init(&buffer);
+
+    var threaded: std.Io.Threaded = .init(fba.allocator());
+    defer threaded.deinit();
+    const io = threaded.io();
+
+    const a: ByteArray256 = .{ .x = @splat(2) };
+    const b: ByteArray256 = .{ .x = @splat(3) };
+    const expected: ByteArray512 = .{ .x = @as([32]u8, @splat(2)) ++ @as([32]u8, @splat(3)) };
+
+    var group: std.Io.Group = .init;
+    var result: ByteArray512 = undefined;
+    group.async(io, concatByteArraysResultPtr, .{ a, b, &result });
+    group.wait(io);
+    try std.testing.expectEqualSlices(u8, &expected.x, &result.x);
+}
```
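
Both regression tests use context and result types (`ByteArray256`, `ByteArray512`) whose 32- and 64-byte alignments exceed the closure headers' natural alignment, exercising exactly the over-aligned path that previously panicked; the `error.ConcurrencyUnavailable` branch keeps the first test passing on single-threaded builds.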
