Skip to content

Commit 685c5cd

Browse files
committed Jan 26, 2021
Adding tests for the various iree_task_t types.
1 parent cee8167 commit 685c5cd

21 files changed

+823
-70
lines changed
 

‎iree/hal/local/task_queue.c

+8-1
Original file line numberDiff line numberDiff line change
@@ -427,11 +427,18 @@ static iree_status_t iree_hal_task_queue_submit_batch(
427427
// NOTE: if we fail from here on we must drop the retire_cmd arena.
428428
iree_status_t status = iree_ok_status();
429429

430+
// A fence we'll use to detect when the entire submission has completed.
431+
// TODO(benvanik): fold into the retire command.
432+
iree_task_fence_t* fence = NULL;
433+
status =
434+
iree_task_executor_acquire_fence(queue->executor, &queue->scope, &fence);
435+
iree_task_set_completion_task(&retire_cmd->task.header, &fence->header);
436+
430437
// Task to fork and wait for unsatisfied semaphore dependencies.
431438
// This is optional and only required if we have previous submissions still
432439
// in-flight - if the queue is empty then we can directly schedule the waits.
433440
iree_hal_task_queue_wait_cmd_t* wait_cmd = NULL;
434-
if (batch->wait_semaphores.count > 0) {
441+
if (iree_status_is_ok(status) && batch->wait_semaphores.count > 0) {
435442
status = iree_hal_task_queue_wait_cmd_allocate(
436443
&queue->scope, &batch->wait_semaphores, &retire_cmd->arena, &wait_cmd);
437444
}

‎iree/task/BUILD

+19
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,25 @@ cc_test(
124124
],
125125
)
126126

127+
cc_test(
128+
name = "task_tests",
129+
srcs = [
130+
"task_test_barrier.cc",
131+
"task_test_call.cc",
132+
"task_test_dispatch.cc",
133+
"task_test_fence.cc",
134+
"task_test_nop.cc",
135+
"task_test_wait.cc",
136+
],
137+
deps = [
138+
":task",
139+
"//iree/base:api",
140+
"//iree/task/testing:task_test",
141+
"//iree/testing:gtest",
142+
"//iree/testing:gtest_main",
143+
],
144+
)
145+
127146
cc_test(
128147
name = "topology_test",
129148
srcs = ["topology_test.cc"],

‎iree/task/CMakeLists.txt

+18
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,24 @@ iree_cc_test(
122122
iree::testing::gtest_main
123123
)
124124

125+
iree_cc_test(
126+
NAME
127+
task_tests
128+
SRCS
129+
"task_test_barrier.cc"
130+
"task_test_call.cc"
131+
"task_test_dispatch.cc"
132+
"task_test_fence.cc"
133+
"task_test_nop.cc"
134+
"task_test_wait.cc"
135+
DEPS
136+
::task
137+
iree::base::api
138+
iree::task::testing::task_test
139+
iree::testing::gtest
140+
iree::testing::gtest_main
141+
)
142+
125143
iree_cc_test(
126144
NAME
127145
topology_test

‎iree/task/executor.c

+27-22
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,18 @@ iree_status_t iree_task_executor_create(
8686
// executor and since we know the precise lifetime of them we can keep them
8787
// entirely within the system here.
8888
if (iree_status_is_ok(status)) {
89-
status = iree_task_pool_initialize(
90-
allocator, sizeof(iree_task_dispatch_slice_t),
91-
worker_count * IREE_TASK_EXECUTOR_INITIAL_SLICE_RESERVATION_PER_WORKER,
92-
&executor->slice_task_pool);
89+
status = iree_task_pool_initialize(allocator, sizeof(iree_task_fence_t), 8,
90+
&executor->fence_task_pool);
9391
}
9492
if (iree_status_is_ok(status)) {
9593
status = iree_task_pool_initialize(
96-
allocator, sizeof(iree_task_dispatch_shard_t),
97-
worker_count * IREE_TASK_EXECUTOR_INITIAL_SHARD_RESERVATION_PER_WORKER,
98-
&executor->shard_task_pool);
94+
allocator,
95+
iree_max(sizeof(iree_task_dispatch_shard_t),
96+
sizeof(iree_task_dispatch_slice_t)),
97+
worker_count *
98+
iree_max(IREE_TASK_EXECUTOR_INITIAL_SHARD_RESERVATION_PER_WORKER,
99+
IREE_TASK_EXECUTOR_INITIAL_SLICE_RESERVATION_PER_WORKER),
100+
&executor->dispatch_task_pool);
99101
}
100102

101103
// Bring up the workers; the threads will be created here but be suspended
@@ -169,8 +171,8 @@ static void iree_task_executor_destroy(iree_task_executor_t* executor) {
169171
iree_slim_mutex_deinitialize(&executor->coordinator_mutex);
170172
iree_atomic_task_slist_deinitialize(&executor->incoming_ready_slist);
171173
iree_atomic_task_slist_deinitialize(&executor->incoming_waiting_slist);
172-
iree_task_pool_deinitialize(&executor->slice_task_pool);
173-
iree_task_pool_deinitialize(&executor->shard_task_pool);
174+
iree_task_pool_deinitialize(&executor->fence_task_pool);
175+
iree_task_pool_deinitialize(&executor->dispatch_task_pool);
174176
iree_allocator_free(executor->allocator, executor);
175177

176178
IREE_TRACE_ZONE_END(z0);
@@ -188,6 +190,19 @@ void iree_task_executor_release(iree_task_executor_t* executor) {
188190
}
189191
}
190192

193+
iree_status_t iree_task_executor_acquire_fence(iree_task_executor_t* executor,
194+
iree_task_scope_t* scope,
195+
iree_task_fence_t** out_fence) {
196+
*out_fence = NULL;
197+
iree_task_fence_t* fence = NULL;
198+
IREE_RETURN_IF_ERROR(iree_task_pool_acquire(&executor->fence_task_pool,
199+
(iree_task_t**)&fence));
200+
iree_task_fence_initialize(scope, fence);
201+
fence->header.pool = &executor->fence_task_pool;
202+
*out_fence = fence;
203+
return iree_ok_status();
204+
}
205+
191206
// Schedules a generic task to a worker matching its affinity.
192207
// The task will be posted to the worker mailbox and available for the worker to
193208
// begin processing as soon as the |post_batch| is submitted.
@@ -262,11 +277,11 @@ void iree_task_executor_schedule_ready_tasks(
262277
} else {
263278
if (task->flags & IREE_TASK_FLAG_DISPATCH_SLICED) {
264279
iree_task_dispatch_issue_sliced((iree_task_dispatch_t*)task,
265-
&executor->slice_task_pool,
280+
&executor->dispatch_task_pool,
266281
pending_submission, post_batch);
267282
} else {
268283
iree_task_dispatch_issue_sharded((iree_task_dispatch_t*)task,
269-
&executor->shard_task_pool,
284+
&executor->dispatch_task_pool,
270285
pending_submission, post_batch);
271286
}
272287
}
@@ -520,17 +535,7 @@ static void iree_task_executor_wait_any_task(
520535
void iree_task_executor_coordinate(iree_task_executor_t* executor,
521536
iree_task_worker_t* current_worker,
522537
bool speculative) {
523-
if (speculative) {
524-
if (!iree_slim_mutex_try_lock(&executor->coordinator_mutex)) {
525-
// Another thread is already holding the coordination lock.
526-
// Return to the caller to wait for it to finish.
527-
// TODO(benvanik): spin here if it's likely we'll have work after the
528-
// other coordinator finishes - that way we don't enter the wait.
529-
return;
530-
}
531-
} else {
532-
iree_slim_mutex_lock(&executor->coordinator_mutex);
533-
}
538+
iree_slim_mutex_lock(&executor->coordinator_mutex);
534539
IREE_TRACE_ZONE_BEGIN(z0);
535540

536541
// We may be adding tasks/waiting/etc on each pass through coordination - to

‎iree/task/executor.h

+5
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ void iree_task_executor_retain(iree_task_executor_t* executor);
319319
// Releases the given |executor| from the caller.
320320
void iree_task_executor_release(iree_task_executor_t* executor);
321321

322+
// Acquires a fence for the given |scope| from the executor fence pool.
323+
iree_status_t iree_task_executor_acquire_fence(iree_task_executor_t* executor,
324+
iree_task_scope_t* scope,
325+
iree_task_fence_t** out_fence);
326+
322327
// TODO(benvanik): scheduling mode mutation, compute quota control, etc.
323328

324329
// Submits a batch of tasks for execution.

‎iree/task/executor_impl.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ struct iree_task_executor_s {
5050
// Pools of transient dispatch tasks shared across all workers.
5151
// Depending on configuration the task pool may allocate after creation using
5252
// the allocator provided upon executor creation.
53-
iree_task_pool_t slice_task_pool;
54-
iree_task_pool_t shard_task_pool;
53+
iree_task_pool_t fence_task_pool;
54+
iree_task_pool_t dispatch_task_pool;
5555

5656
// A list of incoming tasks that are ready to execute immediately.
5757
// The list is LIFO and we require that task lists are reversed by the

‎iree/task/executor_test.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ TEST(ExecutorTest, Any) {
123123
},
124124
0),
125125
workgroup_size_1, workgroup_count_1, &dispatch1);
126-
// dispatch1.header.flags |= IREE_TASK_FLAG_DISPATCH_SLICED;
126+
dispatch1.header.flags |= IREE_TASK_FLAG_DISPATCH_SLICED;
127127

128128
//
129129
iree_task_call_t call1;
@@ -155,9 +155,9 @@ TEST(ExecutorTest, Any) {
155155
#endif
156156

157157
// fence
158-
iree_task_fence_t fence0;
159-
iree_task_fence_initialize(&scope_a, &fence0);
160-
iree_task_set_completion_task(&call1.header, &fence0.header);
158+
iree_task_fence_t* fence0 = NULL;
159+
IREE_CHECK_OK(iree_task_executor_acquire_fence(executor, &scope_a, &fence0));
160+
iree_task_set_completion_task(&call1.header, &fence0->header);
161161

162162
//
163163
iree_task_submission_t sub0;

‎iree/task/scope.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ void iree_task_scope_fail(iree_task_scope_t* scope, iree_task_t* task,
116116

117117
bool iree_task_scope_is_idle(iree_task_scope_t* scope) {
118118
return iree_atomic_load_int32(&scope->pending_submissions,
119-
iree_memory_order_relaxed) == 0;
119+
iree_memory_order_acquire) == 0;
120120
}
121121

122122
iree_status_t iree_task_scope_wait_idle(iree_task_scope_t* scope,

‎iree/task/task.c

+23-12
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ void iree_task_discard(iree_task_t* task, iree_task_list_t* discard_worklist) {
100100
// TODO(benvanik): signal as error.
101101
// iree_task_fence_t* fence_task = (iree_task_fence_t*)task;
102102
iree_atomic_fetch_sub_int32(&task->scope->pending_submissions, 1,
103-
iree_memory_order_relaxed);
103+
iree_memory_order_release);
104104
break;
105105
}
106106
case IREE_TASK_TYPE_WAIT:
@@ -232,7 +232,7 @@ void iree_task_fence_initialize(iree_task_scope_t* scope,
232232
iree_task_fence_t* out_task) {
233233
iree_task_initialize(IREE_TASK_TYPE_FENCE, scope, &out_task->header);
234234
iree_atomic_fetch_add_int32(&scope->pending_submissions, 1,
235-
iree_memory_order_relaxed);
235+
iree_memory_order_release);
236236
}
237237

238238
void iree_task_fence_retire(iree_task_fence_t* task,
@@ -408,6 +408,14 @@ void iree_task_dispatch_issue_sliced(iree_task_dispatch_t* dispatch_task,
408408
memcpy(workgroup_count, dispatch_task->workgroup_count.value,
409409
sizeof(workgroup_count));
410410
}
411+
uint32_t total_workgroup_count =
412+
workgroup_count[0] * workgroup_count[1] * workgroup_count[2];
413+
if (total_workgroup_count == 0) {
414+
// No workgroups to execute - bail early.
415+
iree_task_dispatch_retire(dispatch_task, pending_submission);
416+
IREE_TRACE_ZONE_END(z0);
417+
return;
418+
}
411419

412420
#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
413421
char xyz_string[32];
@@ -449,12 +457,15 @@ void iree_task_dispatch_issue_sliced(iree_task_dispatch_t* dispatch_task,
449457
workgroup_base[1] = slice_y * tiles_per_slice_y;
450458
workgroup_base[2] = slice_z * tiles_per_slice_z;
451459
uint32_t workgroup_range[3];
452-
workgroup_range[0] = iree_min(
453-
workgroup_count[0], workgroup_base[0] + tiles_per_slice_x - 1);
454-
workgroup_range[1] = iree_min(
455-
workgroup_count[1], workgroup_base[1] + tiles_per_slice_y - 1);
456-
workgroup_range[2] = iree_min(
457-
workgroup_count[2], workgroup_base[2] + tiles_per_slice_z - 1);
460+
workgroup_range[0] = iree_min(workgroup_count[0],
461+
workgroup_base[0] + tiles_per_slice_x) -
462+
1;
463+
workgroup_range[1] = iree_min(workgroup_count[1],
464+
workgroup_base[1] + tiles_per_slice_y) -
465+
1;
466+
workgroup_range[2] = iree_min(workgroup_count[2],
467+
workgroup_base[2] + tiles_per_slice_z) -
468+
1;
458469

459470
// Allocate and initialize the slice.
460471
iree_task_dispatch_slice_t* slice_task =
@@ -789,10 +800,10 @@ iree_status_t iree_task_dispatch_shard_execute(
789800
// TODO(benvanik): faster math here, especially knowing we pull off N
790801
// sequential indices per reservation.
791802
uint32_t tile_i = tile_index;
792-
tile_context.workgroup_xyz[0] = tile_i % (workgroup_count_x + 1);
793-
tile_i /= (workgroup_count_x + 1);
794-
tile_context.workgroup_xyz[1] = tile_i % (workgroup_count_y + 1);
795-
tile_i /= (workgroup_count_y + 1);
803+
tile_context.workgroup_xyz[0] = tile_i % workgroup_count_x;
804+
tile_i /= workgroup_count_x;
805+
tile_context.workgroup_xyz[1] = tile_i % workgroup_count_y;
806+
tile_i /= workgroup_count_y;
796807
tile_context.workgroup_xyz[2] = tile_i;
797808

798809
IREE_TRACE_ZONE_BEGIN_NAMED(z_tile,

0 commit comments

Comments (0)