diff --git a/src/execution/__tests__/mapAsyncIterable-test.ts b/src/execution/__tests__/mapAsyncIterable-test.ts index d1312bb94a..0e3a57ffe2 100644 --- a/src/execution/__tests__/mapAsyncIterable-test.ts +++ b/src/execution/__tests__/mapAsyncIterable-test.ts @@ -91,15 +91,10 @@ describe('mapAsyncIterable', () => { it('allows returning early from mapped async generator', async () => { async function* source() { - try { - yield 1; - /* c8 ignore next 3 */ - yield 2; - yield 3; // Shouldn't be reached. - } finally { - // eslint-disable-next-line no-unsafe-finally - return 'The End'; - } + yield 1; + /* c8 ignore next 3 */ + yield 2; + yield 3; // Shouldn't be reached. } const doubles = mapAsyncIterable(source(), (x) => x + x); @@ -108,8 +103,8 @@ describe('mapAsyncIterable', () => { expect(await doubles.next()).to.deep.equal({ value: 4, done: false }); // Early return - expect(await doubles.return('')).to.deep.equal({ - value: 'The End', + expect(await doubles.return()).to.deep.equal({ + value: undefined, done: true, }); @@ -147,7 +142,7 @@ describe('mapAsyncIterable', () => { expect(await doubles.next()).to.deep.equal({ value: 4, done: false }); // Early return - expect(await doubles.return(0)).to.deep.equal({ + expect(await doubles.return()).to.deep.equal({ value: undefined, done: true, }); @@ -155,15 +150,10 @@ describe('mapAsyncIterable', () => { it('passes through early return from async values', async () => { async function* source() { - try { - yield 'a'; - /* c8 ignore next 3 */ - yield 'b'; - yield 'c'; // Shouldn't be reached. - } finally { - yield 'Done'; - yield 'Last'; - } + yield 'a'; + /* c8 ignore next 3 */ + yield 'b'; + yield 'c'; // Shouldn't be reached. } const doubles = mapAsyncIterable(source(), (x) => x + x); @@ -173,14 +163,14 @@ describe('mapAsyncIterable', () => { // Early return expect(await doubles.return()).to.deep.equal({ - value: 'DoneDone', - done: false, + value: undefined, + done: true, }); - // Subsequent next calls may yield from finally block + // Subsequent next calls expect(await doubles.next()).to.deep.equal({ - value: 'LastLast', - done: false, + value: undefined, + done: true, }); expect(await doubles.next()).to.deep.equal({ value: undefined, @@ -260,14 +250,10 @@ describe('mapAsyncIterable', () => { it('passes through caught errors through async generators', async () => { async function* source() { - try { - yield 1; - /* c8 ignore next 2 */ - yield 2; - yield 3; // Shouldn't be reached. - } catch (e) { - yield e; - } + yield 1; + /* c8 ignore next 2 */ + yield 2; + yield 3; // Shouldn't be reached. } const doubles = mapAsyncIterable(source(), (x) => x + x); @@ -276,11 +262,9 @@ describe('mapAsyncIterable', () => { expect(await doubles.next()).to.deep.equal({ value: 4, done: false }); // Throw error - expect(await doubles.throw('Ouch')).to.deep.equal({ - value: 'OuchOuch', - done: false, - }); + await expectPromise(doubles.throw(new Error('Ouch'))).toRejectWith('Ouch'); + // Subsequent next calls expect(await doubles.next()).to.deep.equal({ value: undefined, done: true, diff --git a/src/execution/mapAsyncIterable.ts b/src/execution/mapAsyncIterable.ts index d54c3f49cb..d38bbcdd82 100644 --- a/src/execution/mapAsyncIterable.ts +++ b/src/execution/mapAsyncIterable.ts @@ -1,33 +1,18 @@ +import { isPromise } from '../jsutils/isPromise.js'; import type { PromiseOrValue } from '../jsutils/PromiseOrValue.js'; +import { withCleanup } from './withCleanup.js'; + /** * Given an AsyncIterable and a callback function, return an AsyncIterator * which produces values mapped via calling the callback function. */ -export function mapAsyncIterable( - iterable: AsyncGenerator | AsyncIterable, +export function mapAsyncIterable( + iterable: AsyncGenerator | AsyncIterable, callback: (value: T) => PromiseOrValue, -): AsyncGenerator { - const iterator = iterable[Symbol.asyncIterator](); - - async function mapResult( - promise: Promise>, - ): Promise> { - const result = await promise; - if (result.done) { - return result; - } - - const value = result.value; - try { - return { value: await callback(value), done: false }; - } catch (error) { - await returnIgnoringErrors(); - throw error; - } - } - - async function returnIgnoringErrors(): Promise { +): AsyncGenerator { + return withCleanup(mapAsyncIterableImpl(iterable, callback), async () => { + const iterator = iterable[Symbol.asyncIterator](); if (typeof iterator.return === 'function') { try { await iterator.return(); /* c8 ignore start */ @@ -36,44 +21,19 @@ export function mapAsyncIterable( /* ignore error */ } /* c8 ignore stop */ } - } - - const asyncDispose: typeof Symbol.asyncDispose = - Symbol.asyncDispose /* c8 ignore start */ ?? - Symbol.for('Symbol.asyncDispose'); /* c8 ignore stop */ - - return { - async next() { - return mapResult(iterator.next()); - }, - async return(): Promise> { - // If iterator.return() does not exist, then type R must be undefined. - return typeof iterator.return === 'function' - ? mapResult(iterator.return()) - : { value: undefined as any, done: true }; - }, - async throw(error?: unknown) { - if (typeof iterator.throw === 'function') { - return mapResult(iterator.throw(error)); - } - - if (typeof iterator.return === 'function') { - await returnIgnoringErrors(); - } + }); +} - throw error; - }, - [Symbol.asyncIterator]() { - return this; - }, - async [asyncDispose]() { - await this.return(undefined as R); - if ( - typeof (iterable as AsyncGenerator)[asyncDispose] === - 'function' - ) { - await (iterable as AsyncGenerator)[asyncDispose](); - } - }, - }; +async function* mapAsyncIterableImpl( + iterable: AsyncGenerator | AsyncIterable, + mapFn: (value: T) => PromiseOrValue, +): AsyncGenerator { + for await (const value of iterable) { + const result = mapFn(value); + if (isPromise(result)) { + yield await result; + continue; + } + yield result; + } }