|
| 1 | +package wit_async |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "runtime" |
| 6 | + "unsafe" |
| 7 | + |
| 8 | + "go.bytecodealliance.org/wit_runtime" |
| 9 | +) |
| 10 | + |
| 11 | +const EVENT_NONE uint32 = 0 |
| 12 | +const EVENT_SUBTASK uint32 = 1 |
| 13 | +const EVENT_STREAM_READ uint32 = 2 |
| 14 | +const EVENT_STREAM_WRITE uint32 = 3 |
| 15 | +const EVENT_FUTURE_READ uint32 = 4 |
| 16 | +const EVENT_FUTURE_WRITE uint32 = 5 |
| 17 | + |
| 18 | +const STATUS_STARTING uint32 = 0 |
| 19 | +const STATUS_STARTED uint32 = 1 |
| 20 | +const STATUS_RETURNED uint32 = 2 |
| 21 | + |
| 22 | +const CALLBACK_CODE_EXIT uint32 = 0 |
| 23 | +const CALLBACK_CODE_YIELD uint32 = 1 |
| 24 | +const CALLBACK_CODE_WAIT uint32 = 2 |
| 25 | + |
| 26 | +const RETURN_CODE_BLOCKED uint32 = 0xFFFFFFFF |
| 27 | +const RETURN_CODE_COMPLETED uint32 = 0 |
| 28 | +const RETURN_CODE_DROPPED uint32 = 1 |
| 29 | + |
| 30 | +type unit struct{} |
| 31 | + |
| 32 | +type taskState struct { |
| 33 | + channel chan unit |
| 34 | + waitableSet uint32 |
| 35 | + pending map[uint32]chan uint32 |
| 36 | + yielding chan unit |
| 37 | + pinner runtime.Pinner |
| 38 | +} |
| 39 | + |
| 40 | +var state *taskState = nil |
| 41 | + |
| 42 | +func Run(closure func()) uint32 { |
| 43 | + state = &taskState{ |
| 44 | + make(chan unit), |
| 45 | + 0, |
| 46 | + make(map[uint32]chan uint32), |
| 47 | + nil, |
| 48 | + runtime.Pinner{}, |
| 49 | + } |
| 50 | + state.pinner.Pin(state) |
| 51 | + |
| 52 | + defer func() { |
| 53 | + state = nil |
| 54 | + }() |
| 55 | + |
| 56 | + go closure() |
| 57 | + |
| 58 | + return callback(EVENT_NONE, 0, 0) |
| 59 | +} |
| 60 | + |
| 61 | +func Callback(event0, event1, event2 uint32) uint32 { |
| 62 | + state = (*taskState)(contextGet()) |
| 63 | + contextSet(nil) |
| 64 | + |
| 65 | + return callback(event0, event1, event2) |
| 66 | +} |
| 67 | + |
| 68 | +//go:linkname wasiOnIdle runtime.wasiOnIdle |
| 69 | +func wasiOnIdle(callback func() bool) |
| 70 | + |
| 71 | +func callback(event0, event1, event2 uint32) uint32 { |
| 72 | + yielding := state.yielding |
| 73 | + if state.yielding != nil { |
| 74 | + state.yielding = nil |
| 75 | + yielding <- unit{} |
| 76 | + } |
| 77 | + |
| 78 | + // Tell the Go scheduler to write to `state.channel` only after all |
| 79 | + // goroutines have either blocked or exited. This allows us to reliably |
| 80 | + // delay returning control to the host until there's truly nothing more |
| 81 | + // we can do in the guest. |
| 82 | + // |
| 83 | + // Note that this function is _not_ currently part of upstream Go; it |
| 84 | + // requires [this |
| 85 | + // patch](https://github.com/dicej/go/commit/40fc123d5bce6448fc4e4601fd33bad4250b36a5) |
| 86 | + wasiOnIdle(func() bool { |
| 87 | + state.channel <- unit{} |
| 88 | + return true |
| 89 | + }) |
| 90 | + defer wasiOnIdle(func() bool { |
| 91 | + return false |
| 92 | + }) |
| 93 | + |
| 94 | + for { |
| 95 | + switch event0 { |
| 96 | + case EVENT_NONE: |
| 97 | + |
| 98 | + case EVENT_SUBTASK: |
| 99 | + switch event2 { |
| 100 | + case STATUS_STARTING: |
| 101 | + panic(fmt.Sprintf("unexpected subtask status: %v", event2)) |
| 102 | + |
| 103 | + case STATUS_STARTED: |
| 104 | + |
| 105 | + case STATUS_RETURNED: |
| 106 | + waitableJoin(event1, 0) |
| 107 | + subtaskDrop(event1) |
| 108 | + channel := state.pending[event1] |
| 109 | + delete(state.pending, event1) |
| 110 | + channel <- event2 |
| 111 | + |
| 112 | + default: |
| 113 | + panic("todo") |
| 114 | + } |
| 115 | + |
| 116 | + case EVENT_STREAM_READ, EVENT_STREAM_WRITE, EVENT_FUTURE_READ, EVENT_FUTURE_WRITE: |
| 117 | + waitableJoin(event1, 0) |
| 118 | + channel := state.pending[event1] |
| 119 | + delete(state.pending, event1) |
| 120 | + channel <- event2 |
| 121 | + |
| 122 | + default: |
| 123 | + panic("todo") |
| 124 | + } |
| 125 | + |
| 126 | + // Block this goroutine until the scheduler wakes us up. |
| 127 | + (<-state.channel) |
| 128 | + |
| 129 | + if state.yielding != nil { |
| 130 | + contextSet(unsafe.Pointer(state)) |
| 131 | + if len(state.pending) == 0 { |
| 132 | + return CALLBACK_CODE_YIELD |
| 133 | + } else { |
| 134 | + if state.waitableSet == 0 { |
| 135 | + panic("unreachable") |
| 136 | + } |
| 137 | + event0, event1, event2 = func() (uint32, uint32, uint32) { |
| 138 | + pinner := runtime.Pinner{} |
| 139 | + defer pinner.Unpin() |
| 140 | + buffer := wit_runtime.Allocate(&pinner, 8, 4) |
| 141 | + event0 := waitableSetPoll(state.waitableSet, buffer) |
| 142 | + return event0, |
| 143 | + unsafe.Slice((*uint32)(buffer), 2)[0], |
| 144 | + unsafe.Slice((*uint32)(buffer), 2)[1] |
| 145 | + }() |
| 146 | + if event0 == EVENT_NONE { |
| 147 | + return CALLBACK_CODE_YIELD |
| 148 | + } |
| 149 | + } |
| 150 | + } else if len(state.pending) == 0 { |
| 151 | + state.pinner.Unpin() |
| 152 | + if state.waitableSet != 0 { |
| 153 | + waitableSetDrop(state.waitableSet) |
| 154 | + } |
| 155 | + return CALLBACK_CODE_EXIT |
| 156 | + } else { |
| 157 | + if state.waitableSet == 0 { |
| 158 | + panic("unreachable") |
| 159 | + } |
| 160 | + contextSet(unsafe.Pointer(state)) |
| 161 | + return CALLBACK_CODE_WAIT | (state.waitableSet << 4) |
| 162 | + } |
| 163 | + } |
| 164 | +} |
| 165 | + |
| 166 | +func SubtaskWait(status uint32) { |
| 167 | + subtask := status >> 4 |
| 168 | + status = status & 0xF |
| 169 | + |
| 170 | + switch status { |
| 171 | + case STATUS_STARTING, STATUS_STARTED: |
| 172 | + if state.waitableSet == 0 { |
| 173 | + state.waitableSet = waitableSetNew() |
| 174 | + } |
| 175 | + waitableJoin(subtask, state.waitableSet) |
| 176 | + channel := make(chan uint32) |
| 177 | + state.pending[subtask] = channel |
| 178 | + (<-channel) |
| 179 | + |
| 180 | + case STATUS_RETURNED: |
| 181 | + |
| 182 | + default: |
| 183 | + panic(fmt.Sprintf("unexpected subtask status: %v", status)) |
| 184 | + } |
| 185 | +} |
| 186 | + |
| 187 | +func FutureOrStreamWait(code uint32, handle int32) (uint32, uint32) { |
| 188 | + if code == RETURN_CODE_BLOCKED { |
| 189 | + if state.waitableSet == 0 { |
| 190 | + state.waitableSet = waitableSetNew() |
| 191 | + } |
| 192 | + waitableJoin(uint32(handle), state.waitableSet) |
| 193 | + channel := make(chan uint32) |
| 194 | + state.pending[uint32(handle)] = channel |
| 195 | + code = (<-channel) |
| 196 | + } |
| 197 | + |
| 198 | + count := code >> 4 |
| 199 | + code = code & 0xF |
| 200 | + |
| 201 | + return code, count |
| 202 | +} |
| 203 | + |
| 204 | +func Yield() { |
| 205 | + channel := make(chan unit) |
| 206 | + state.yielding = channel |
| 207 | + (<-channel) |
| 208 | +} |
| 209 | + |
| 210 | +//go:wasmimport $root [waitable-set-new] |
| 211 | +func waitableSetNew() uint32 |
| 212 | + |
| 213 | +//go:wasmimport $root [waitable-set-poll] |
| 214 | +func waitableSetPoll(waitableSet uint32, eventPayload unsafe.Pointer) uint32 |
| 215 | + |
| 216 | +//go:wasmimport $root [waitable-set-drop] |
| 217 | +func waitableSetDrop(waitableSet uint32) |
| 218 | + |
| 219 | +//go:wasmimport $root [waitable-join] |
| 220 | +func waitableJoin(waitable, waitableSet uint32) |
| 221 | + |
| 222 | +//go:wasmimport $root [context-get-0] |
| 223 | +func contextGet() unsafe.Pointer |
| 224 | + |
| 225 | +//go:wasmimport $root [context-set-0] |
| 226 | +func contextSet(value unsafe.Pointer) |
| 227 | + |
| 228 | +//go:wasmimport $root [subtask-drop] |
| 229 | +func subtaskDrop(subtask uint32) |
0 commit comments