Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion adaptive/learner/learner2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def remove_unfinished(self) -> None:
if p not in self.data:
self._stack[p] = np.inf

def plot(self, n=None, tri_alpha=0):
def plot(self, n=None, tri_alpha=0.0):
r"""Plot the Learner2D's current state.

This plot function interpolates the data on a regular grid.
Expand Down
10 changes: 7 additions & 3 deletions adaptive/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,11 +911,12 @@ def simple(
npoints_goal: int | None = None,
end_time_goal: datetime | None = None,
duration_goal: timedelta | int | float | None = None,
points_per_ask: int = 1,
):
"""Run the learner until the goal is reached.

Requests a single point from the learner, evaluates
the function to be learned, and adds the point to the
Requests points from the learner, evaluates
the function to be learned, and adds the points to the
learner, until the goal is reached, blocking the current
thread.

Expand Down Expand Up @@ -946,6 +947,9 @@ def simple(
calculation. Stop when the current time is larger or equal than
``start_time + duration_goal``. ``duration_goal`` can be a number
indicating the number of seconds.
points_per_ask : int, optional
The number of points to ask for between every interpolation rerun. Defaults
to 1, which can introduce significant overhead on long runs.
"""
goal = _goal(
learner,
Expand All @@ -958,7 +962,7 @@ def simple(
)
assert goal is not None
while not goal(learner):
xs, _ = learner.ask(1)
xs, _ = learner.ask(points_per_ask)
for x in xs:
y = learner.function(x)
learner.tell(x, y)
Expand Down
62 changes: 62 additions & 0 deletions adaptive/tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,65 @@ def test_auto_goal():
simple(learner, auto_goal(duration=1e-2, learner=learner))
t_end = time.time()
assert t_end - t_start >= 1e-2


def test_simple_points_per_ask():
"""Test that the simple runner respects the points_per_ask parameter (PR #484)."""

def f(x):
return x**2

# Test with 1D learner asking for multiple points at once
learner1 = Learner1D(f, (-1, 1))
simple(learner1, npoints_goal=20, points_per_ask=5)
assert learner1.npoints >= 20

# Test with 2D learner
def f2d(xy):
x, y = xy
return x**2 + y**2

learner2 = Learner2D(f2d, ((-1, 1), (-1, 1)))
simple(learner2, npoints_goal=32, points_per_ask=8)
assert learner2.npoints >= 32

# Test that default behavior (points_per_ask=1) is preserved
learner3 = Learner1D(f, (-1, 1))
simple(learner3, npoints_goal=15)
assert learner3.npoints >= 15

# Test performance improvement: more points per ask = fewer ask calls
ask_count = 0
original_ask = Learner1D.ask

def counting_ask(self, n, tell_pending=True):
nonlocal ask_count
ask_count += 1
return original_ask(self, n, tell_pending)

# Monkey patch to count ask calls
Learner1D.ask = counting_ask

try:
# Test with points_per_ask=1 (default)
learner4 = Learner1D(f, (-1, 1))
ask_count = 0
simple(learner4, npoints_goal=10, points_per_ask=1)
ask_count_single = ask_count

# Test with points_per_ask=5
learner5 = Learner1D(f, (-1, 1))
ask_count = 0
simple(learner5, npoints_goal=10, points_per_ask=5)
ask_count_batch = ask_count

# When asking for 5 points at a time, we should have fewer ask calls
assert ask_count_batch < ask_count_single

# Both learners should have reached their goal
assert learner4.npoints >= 10
assert learner5.npoints >= 10

finally:
# Restore original method
Learner1D.ask = original_ask
Loading