//===----------------------------------------------------------------------===// // // This source file is part of the Swift Async Algorithms open source project // // Copyright (c) 2022 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // //===----------------------------------------------------------------------===// import XCTest import AsyncAlgorithms final class TestBufferedByteIterator: XCTestCase { actor Isolated<T: Sendable> { var value: T init(_ value: T) { self.value = value } func update(_ value: T) async { self.value = value } } func test_immediately_empty() async throws { let reloaded = Isolated(false) var iterator = AsyncBufferedByteIterator(capacity: 3) { buffer in XCTAssertEqual(buffer.count, 3) await reloaded.update(true) return 0 } var wasReloaded = await reloaded.value XCTAssertFalse(wasReloaded) let byte = try await iterator.next() XCTAssertNil(byte) wasReloaded = await reloaded.value XCTAssertTrue(wasReloaded) } func test_one_pass() async throws { let reloaded = Isolated(0) var iterator = AsyncBufferedByteIterator(capacity: 3) { buffer in XCTAssertEqual(buffer.count, 3) let count = await reloaded.value await reloaded.update(count + 1) if count >= 1 { return 0 } buffer.copyBytes(from: [1, 2, 3]) return 3 } var reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 0) var byte = try await iterator.next() XCTAssertEqual(byte, 1) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 1) byte = try await iterator.next() XCTAssertEqual(byte, 2) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 1) byte = try await iterator.next() XCTAssertEqual(byte, 3) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 1) byte = try await iterator.next() XCTAssertNil(byte) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 2) byte = try await iterator.next() XCTAssertNil(byte) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 2) } func test_three_pass() async throws { let reloaded = Isolated(0) var iterator = AsyncBufferedByteIterator(capacity: 3) { buffer in XCTAssertEqual(buffer.count, 3) let count = await reloaded.value await reloaded.update(count + 1) if count >= 3 { return 0 } buffer.copyBytes(from: [1, 2, 3]) return 3 } var reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 0) for n in 1...3 { var byte = try await iterator.next() XCTAssertEqual(byte, 1) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, n) byte = try await iterator.next() XCTAssertEqual(byte, 2) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, n) byte = try await iterator.next() XCTAssertEqual(byte, 3) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, n) } var byte = try await iterator.next() XCTAssertNil(byte) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 4) byte = try await iterator.next() XCTAssertNil(byte) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 4) } func test_three_pass_throwing() async throws { let reloaded = Isolated(0) var iterator = AsyncBufferedByteIterator(capacity: 3) { buffer in XCTAssertEqual(buffer.count, 3) let count = await reloaded.value await reloaded.update(count + 1) if count >= 3 { return 0 } if count == 2 { throw Failure() } buffer.copyBytes(from: [1, 2, 3]) return 3 } var reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 0) for n in 1...3 { do { var byte = try await iterator.next() XCTAssertEqual(byte, 1) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, n) byte = try await iterator.next() XCTAssertEqual(byte, 2) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, n) byte = try await iterator.next() XCTAssertEqual(byte, 3) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, n) } catch { XCTAssertEqual(n, 3) break } } var byte = try await iterator.next() XCTAssertNil(byte) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 3) byte = try await iterator.next() XCTAssertNil(byte) reloadCount = await reloaded.value XCTAssertEqual(reloadCount, 3) } func test_cancellation() async { struct RepeatingBytes: AsyncSequence { typealias Element = UInt8 func makeAsyncIterator() -> AsyncBufferedByteIterator { AsyncBufferedByteIterator(capacity: 3) { buffer in buffer.copyBytes(from: [1, 2, 3]) return 3 } } } let finished = expectation(description: "finished") let iterated = expectation(description: "iterated") let task = Task { var firstIteration = false do { for try await _ in RepeatingBytes() { if !firstIteration { iterated.fulfill() firstIteration = true } } XCTFail("expected to throw a cancellation error") } catch { if error is CancellationError { finished.fulfill() } } } await fulfillment(of: [iterated], timeout: 1.0) // cancellation should ensure the loop finishes // without regards to the remaining underlying sequence task.cancel() await fulfillment(of: [finished], timeout: 1.0) } }