From 8c5100ed693b9a7862b3ea7dd34436bb08a823d0 Mon Sep 17 00:00:00 2001 From: 1himan Date: Tue, 7 Apr 2026 15:02:45 +0530 Subject: [PATCH 1/2] Initial Changes --- mne/viz/evoked.py | 4 +- mne/viz/topo.py | 339 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 332 insertions(+), 11 deletions(-) diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index a62d2379f03..dd384729da4 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1167,7 +1167,9 @@ def plot_evoked_topo( """Plot 2D topography of evoked responses. Clicking on the plot of an individual sensor opens a new figure showing - the evoked response for the selected sensor. + the evoked response for the selected sensor. After a figure is created, + hotkeys and on-figure controls can be used to adjust the y-limits and + switch between MEG channel views. Parameters ---------- diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 5c43d4de48e..03d46874ca0 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -182,9 +182,10 @@ def format_coord_multiaxis(x, y, ch_name=None): if layout is None: layout = find_layout(info) - if on_pick is not None: - callback = partial(_plot_topo_onpick, show_func=on_pick) - fig.canvas.mpl_connect("button_press_event", callback) + if on_pick is not None and not hasattr(fig, "_mne_topo_pick_cid"): + fig._mne_topo_pick_cid = fig.canvas.mpl_connect( + "button_press_event", partial(_plot_topo_onpick) + ) pos = layout.pos.copy() if layout_scale: @@ -220,7 +221,9 @@ def format_coord_multiaxis(x, y, ch_name=None): tick.set_visible(False) ax._mne_ch_name = name ax._mne_ch_idx = ch_idx + ax._mne_ch_type = channel_type(info, ch_idx) ax._mne_ax_face_color = axis_facecolor + ax._mne_onpick = on_pick ax.format_coord = partial(format_coord_multiaxis, ch_name=name) yield ax, ch_idx else: @@ -230,7 +233,9 @@ def format_coord_multiaxis(x, y, ch_name=None): data_lines=list(), _mne_ch_name=name, _mne_ch_idx=ch_idx, + _mne_ch_type=channel_type(info, ch_idx), _mne_ax_face_color=axis_facecolor, + _mne_onpick=on_pick, ) axs.append(ax) if not unified and legend: @@ -372,15 +377,22 @@ def _plot_topo( show_func(ax, ch_idx, tmin=tmin, tmax=tmax, vmin=vmin, vmax=vmax, ylim=ylim_) + if hasattr(fig, "_mne_topo_title"): + fig._mne_topo_title.remove() + del fig._mne_topo_title if title is not None: - plt.figtext(0.03, 0.95, title, color=font_color, fontsize=15, va="top") + fig._mne_topo_title = fig.text( + 0.03, 0.95, title, color=font_color, fontsize=15, va="top" + ) return fig -def _plot_topo_onpick(event, show_func): +def _plot_topo_onpick(event): """Onpick callback that shows a single channel in a new figure.""" orig_ax = event.inaxes + if orig_ax is None: + return fig = orig_ax.figure # If we are doing lasso select, allow it to handle the click instead. @@ -412,11 +424,15 @@ def _plot_topo_onpick(event, show_func): # neither old nor new mode return ch_idx = orig_ax._mne_ch_idx + show_func = orig_ax._mne_onpick + if show_func is None: + return face_color = orig_ax._mne_ax_face_color fig, ax = plt.subplots(1) plt.title(orig_ax._mne_ch_name) ax.set_facecolor(face_color) + ax._mne_ch_type = orig_ax._mne_ch_type # allow custom function to override parameters show_func(ax, ch_idx) @@ -573,14 +589,16 @@ def _plot_timeseries( import matplotlib.pyplot as plt picker_flag = False + lines = [] for data_, color_, times_ in zip(data, color, times): if not picker_flag: # use large tol for picker so we can click anywhere in the axes line = ax.plot(times_, data_[ch_idx], color=color_, picker=True)[0] line.set_pickradius(1e9) + lines.append(line) picker_flag = True else: - ax.plot(times_, data_[ch_idx], color=color_) + lines.append(ax.plot(times_, data_[ch_idx], color=color_)[0]) def _format_coord(x, y, labels, ax): """Create status string based on cursor coordinates.""" @@ -646,6 +664,10 @@ def _rm_cursor(event): plt.connect("motion_notify_event", _cursor_vline) plt.connect("axes_leave_event", _rm_cursor) + if isinstance(ylim, dict): + ylim = ylim.get(getattr(ax, "_mne_ch_type", None)) + if ylim is not None and not any(v is None for v in ylim): + ax.set_ylim(ylim) ymin, ymax = ax.get_ylim() # don't pass vline or hline here (this fxn doesn't do hvline_color): _setup_ax_spines(ax, [], tmin, tmax, ymin, ymax, hline=False) @@ -676,6 +698,10 @@ def _rm_cursor(event): for hline_ in hline: plt.axhline(hline_, color=hvline_color, linewidth=1.0, zorder=10) + if labels is not None: + legend_labels = [label if label else "Unknown" for label in labels] + ax.legend(lines, legend_labels, loc="best", prop={"size": 10}) + if colorbar: plt.colorbar() @@ -868,6 +894,266 @@ def _erfimage_imshow_unified( ) +def _scale_evoked_topo_ylim(ylim, factor): + """Scale topo y-limits while keeping zero fixed when possible.""" + ylim_scaled = dict() + for ch_type, (ymin, ymax) in ylim.items(): + if ymin <= 0 <= ymax: + ymin *= factor + ymax *= factor + elif ymin >= 0: + ymin = 0.0 if ymin == 0 else ymin * factor + ymax *= factor + else: + ymin *= factor + ymax = 0.0 if ymax == 0 else ymax * factor + ylim_scaled[ch_type] = [ymin, ymax] + return ylim_scaled + + +class _TopoInteractive: + """Figure-local controls for plot_evoked_topo.""" + + _ylim_step = 1.25 + + def __init__( + self, + *, + evoked, + layout, + layout_scale, + color, + border, + ylim, + scalings, + title, + proj, + vline, + fig_facecolor, + fig_background, + axis_facecolor, + font_color, + merge_channels, + legend, + noise_cov, + exclude, + select, + ): + import matplotlib.pyplot as plt + from matplotlib.patches import FancyBboxPatch + + info = evoked[0].info + self.evoked = evoked + self.layout = layout + self.layout_scale = layout_scale + self.color = color + self.border = border + self.manual_ylim = ylim + self.scalings = scalings + self.title = title + self.proj = proj + self.vline = vline + self.fig_facecolor = fig_facecolor + self.fig_background = fig_background + self.axis_facecolor = axis_facecolor + self.font_color = font_color + self.legend = legend + self.noise_cov = noise_cov + self.exclude = exclude + self.select = select + self.mode = "grad_rms" if merge_channels else "all" + self.ylim_scale = 1.0 + self.base_ylim = None + self.current_ylim = None + self._buttons = dict() + self._has_mag = len(pick_types(info, meg="mag", ref_meg=False, exclude=[])) > 0 + self._has_grad = ( + len(pick_types(info, meg="grad", ref_meg=False, exclude=[])) > 0 + ) + + self.fig = plt.figure(layout=None) + self.fig.set_facecolor(fig_facecolor) + self.axes = self.fig.add_axes([0.015, 0.085, 0.97, 0.89]) + self.fig._mne_topo_interactive = self + + dark_background = np.mean(_to_rgb(fig_facecolor)) < 0.5 + inactive_face = "#565656" if dark_background else "#e6e6e6" + active_face = "#1f6aa5" + edge_color = "white" if dark_background else "black" + default_text = "white" if dark_background else "black" + self._button_style = dict( + active_face=active_face, + inactive_face=inactive_face, + edge=edge_color, + active_text="white", + inactive_text=default_text, + ) + + specs = [("all", "All [a]", 0.09)] + if self._has_mag: + specs.append(("mag", "Mag [m]", 0.10)) + if self._has_grad: + specs.append(("grad", "Grad [g]", 0.11)) + specs.append(("grad_rms", "Join [j]", 0.11)) + specs.extend((("ylim_dec", "Y- [-]", 0.10), ("ylim_inc", "Y+ [+]", 0.10))) + + x0, y0, height, gap = 0.015, 0.02, 0.04, 0.008 + for name, label, width in specs: + patch = FancyBboxPatch( + (x0, y0), + width, + height, + boxstyle="round,pad=0.01", + linewidth=0.8, + edgecolor=edge_color, + facecolor=inactive_face, + transform=self.fig.transFigure, + zorder=10, + ) + text = self.fig.text( + x0 + width / 2.0, + y0 + height / 2.0, + label, + color=default_text, + ha="center", + va="center", + zorder=11, + ) + self.fig.add_artist(patch) + self._buttons[name] = Bunch( + bounds=(x0, y0, width, height), + patch=patch, + text=text, + ) + x0 += width + gap + + self.fig.canvas.mpl_connect("button_press_event", self._on_button_press) + self.fig.canvas.mpl_connect("key_press_event", self._on_keypress) + self._update_buttons() + + def _update_buttons(self): + for name, button in self._buttons.items(): + active = name == self.mode + button.patch.set_facecolor( + self._button_style["active_face"] + if active + else self._button_style["inactive_face"] + ) + button.text.set_color( + self._button_style["active_text"] + if active + else self._button_style["inactive_text"] + ) + + def _on_button_press(self, event): + if event.button != 1: + return + x, y = self.fig.transFigure.inverted().transform((event.x, event.y)) + for name, button in self._buttons.items(): + x0, y0, width, height = button.bounds + if x0 <= x <= x0 + width and y0 <= y <= y0 + height: + self._handle_action(name) + return + + def _on_keypress(self, event): + key = event.key + if key is None: + return + key = key.lower() + if key in ("-", "_"): + self._handle_action("ylim_dec") + elif key in ("+", "="): + self._handle_action("ylim_inc") + elif key == "a": + self._handle_action("all") + elif key == "m": + self._handle_action("mag") + elif key == "g": + self._handle_action("grad") + elif key == "j": + self._handle_action("grad_rms") + + def _handle_action(self, action): + if action == "ylim_dec": + if self.base_ylim is None: + return + self.ylim_scale /= self._ylim_step + elif action == "ylim_inc": + if self.base_ylim is None: + return + self.ylim_scale *= self._ylim_step + elif action == "mag" and not self._has_mag: + return + elif action in ("grad", "grad_rms") and not self._has_grad: + return + else: + if action == self.mode: + return + self.mode = action + self.ylim_scale = 1.0 + self.render() + + def _get_evoked_for_mode(self): + if self.mode == "mag": + evoked = [e.copy().pick(picks="mag") for e in self.evoked] + layout = self.layout + merge_channels = False + elif self.mode == "grad": + evoked = [e.copy().pick(picks="grad") for e in self.evoked] + layout = self.layout + merge_channels = False + elif self.mode == "grad_rms": + evoked = [e.copy().pick(picks="grad") for e in self.evoked] + layout = None + merge_channels = True + else: + evoked = self.evoked + layout = self.layout + merge_channels = False + return evoked, layout, merge_channels + + def render(self): + for ax in tuple(self.fig.axes): + if ax is not self.axes and ax.get_label() == "background": + ax.remove() + self.axes.clear() + evoked, layout, merge_channels = self._get_evoked_for_mode() + ylim = ( + self.manual_ylim + if self.ylim_scale == 1.0 + else _scale_evoked_topo_ylim(self.base_ylim, self.ylim_scale) + ) + _plot_evoked_topo( + evoked=evoked, + layout=layout, + layout_scale=self.layout_scale, + color=self.color, + border=self.border, + ylim=ylim, + scalings=self.scalings, + title=self.title, + proj=self.proj, + vline=self.vline, + fig_facecolor=self.fig_facecolor, + fig_background=self.fig_background, + axis_facecolor=self.axis_facecolor, + font_color=self.font_color, + merge_channels=merge_channels, + legend=self.legend, + axes=self.axes, + noise_cov=self.noise_cov, + exclude=self.exclude, + select=self.select, + show=False, + interactive=False, + ) + self.current_ylim = deepcopy(self.fig._mne_topo_ylims) + if self.ylim_scale == 1.0: + self.base_ylim = deepcopy(self.current_ylim) + self._update_buttons() + self.fig.canvas.draw_idle() + + def _plot_evoked_topo( evoked, layout=None, @@ -891,6 +1177,7 @@ def _plot_evoked_topo( exclude="bads", select=False, show=True, + interactive=True, ): """Plot 2D topography of evoked responses. @@ -970,6 +1257,9 @@ def _plot_evoked_topo( Show figure if True. .. versionadded:: 0.16.0 + interactive: bool + Whether to display the figure in interactive mode or not. + Defaults to True. Returns ------- @@ -984,6 +1274,32 @@ def _plot_evoked_topo( if type(evoked) not in (tuple, list): evoked = [evoked] + if interactive and axes is None: + topo = _TopoInteractive( + evoked=evoked, + layout=layout, + layout_scale=layout_scale, + color=color, + border=border, + ylim=ylim, + scalings=scalings, + title=title, + proj=proj, + vline=vline, + fig_facecolor=fig_facecolor, + fig_background=fig_background, + axis_facecolor=axis_facecolor, + font_color=font_color, + merge_channels=merge_channels, + legend=legend, + noise_cov=noise_cov, + exclude=exclude, + select=select, + ) + topo.render() + plt_show(show, fig=topo.fig) + return topo.fig + noise_cov = _check_cov(noise_cov, evoked[0].info) if noise_cov is not None: evoked = [whiten_evoked(e, noise_cov) for e in evoked] @@ -1147,6 +1463,9 @@ def _plot_evoked_topo( axes=axes, select=select, ) + legend_ax = axes if axes is not None else fig.axes[0] + fig._mne_topo_ylims = deepcopy(ylim_) + fig._mne_topo_ax = legend_ax add_background_image(fig, fig_background) @@ -1154,11 +1473,11 @@ def _plot_evoked_topo( legend_loc = 0 if legend is True else legend labels = [e.comment if e.comment else "Unknown" for e in evoked] if select: - handles = fig.axes[0].lines[1 : len(evoked) + 1] + handles = legend_ax.lines[1 : len(evoked) + 1] else: - handles = fig.axes[0].lines[: len(evoked)] - legend = plt.legend( - labels=labels, handles=handles, loc=legend_loc, prop={"size": 10} + handles = legend_ax.lines[: len(evoked)] + legend = legend_ax.legend( + handles=handles, labels=labels, loc=legend_loc, prop={"size": 10} ) legend.get_frame().set_facecolor(axis_facecolor) txts = legend.get_texts() From 34de7b46e0999f655f7128ebe44a62b0f9edfaf9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 09:46:32 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/viz/topo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 03d46874ca0..f771e2d8b05 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -1266,8 +1266,6 @@ def _plot_evoked_topo( fig : instance of matplotlib.figure.Figure Images of evoked responses at sensor locations """ - import matplotlib.pyplot as plt - from ..channels.layout import _merge_ch_data, _pair_grad_sensors, find_layout from ..cov import whiten_evoked