11#
22# Class for quick plotting of variables from models
33#
4+ from __future__ import annotations
45import os
56import numpy as np
67import pybamm
@@ -479,24 +480,24 @@ def reset_axis(self):
479480 ): # pragma: no cover
480481 raise ValueError (f"Axis limits cannot be NaN for variables '{ key } '" )
481482
482- def plot (self , t , dynamic = False ):
483+ def plot (self , t : float | list [ float ] , dynamic : bool = False ):
483484 """Produces a quick plot with the internal states at time t.
484485
485486 Parameters
486487 ----------
487- t : float
488- Dimensional time (in 'time_units') at which to plot.
488+ t : float or list of float
489+ Dimensional time (in 'time_units') at which to plot. Can be a single time or a list of times.
489490 dynamic : bool, optional
490491 Determine whether to allocate space for a slider at the bottom of the plot when generating a dynamic plot.
491492 If True, creates a dynamic plot with a slider.
492493 """
493494
494495 plt = import_optional_dependency ("matplotlib.pyplot" )
495496 gridspec = import_optional_dependency ("matplotlib.gridspec" )
496- cm = import_optional_dependency ("matplotlib" , "cm" )
497- colors = import_optional_dependency ("matplotlib" , "colors" )
498497
499- t_in_seconds = t * self .time_scaling_factor
498+ if not isinstance (t , list ):
499+ t = [t ]
500+
500501 self .fig = plt .figure (figsize = self .figsize )
501502
502503 self .gridspec = gridspec .GridSpec (self .n_rows , self .n_cols )
@@ -508,6 +509,11 @@ def plot(self, t, dynamic=False):
508509 # initialize empty handles, to be created only if the appropriate plots are made
509510 solution_handles = []
510511
512+ # Generate distinct colors for each time point
513+ time_colors = plt .cm .coolwarm (
514+ np .linspace (0 , 1 , len (t ))
515+ ) # Use a colormap for distinct colors
516+
511517 for k , (key , variable_lists ) in enumerate (self .variables .items ()):
512518 ax = self .fig .add_subplot (self .gridspec [k ])
513519 self .axes .add (key , ax )
@@ -518,19 +524,17 @@ def plot(self, t, dynamic=False):
518524 ax .xaxis .set_major_locator (plt .MaxNLocator (3 ))
519525 self .plots [key ] = defaultdict (dict )
520526 variable_handles = []
521- # Set labels for the first subplot only (avoid repetition)
527+
522528 if variable_lists [0 ][0 ].dimensions == 0 :
523- # 0D plot: plot as a function of time, indicating time t with a line
529+ # 0D plot: plot as a function of time, indicating multiple times with lines
524530 ax .set_xlabel (f"Time [{ self .time_unit } ]" )
525531 for i , variable_list in enumerate (variable_lists ):
526532 for j , variable in enumerate (variable_list ):
527- if len (variable_list ) == 1 :
528- # single variable -> use linestyle to differentiate model
529- linestyle = self .linestyles [i ]
530- else :
531- # multiple variables -> use linestyle to differentiate
532- # variables (color differentiates models)
533- linestyle = self .linestyles [j ]
533+ linestyle = (
534+ self .linestyles [i ]
535+ if len (variable_list ) == 1
536+ else self .linestyles [j ]
537+ )
534538 full_t = self .ts_seconds [i ]
535539 (self .plots [key ][i ][j ],) = ax .plot (
536540 full_t / self .time_scaling_factor ,
@@ -542,128 +546,104 @@ def plot(self, t, dynamic=False):
542546 solution_handles .append (self .plots [key ][i ][0 ])
543547 y_min , y_max = ax .get_ylim ()
544548 ax .set_ylim (y_min , y_max )
545- (self .time_lines [key ],) = ax .plot (
546- [
547- t_in_seconds / self .time_scaling_factor ,
548- t_in_seconds / self .time_scaling_factor ,
549- ],
550- [y_min , y_max ],
551- "k--" ,
552- lw = 1.5 ,
553- )
549+
550+ # Add vertical lines for each time in the list, using different colors for each time
551+ for idx , t_single in enumerate (t ):
552+ t_in_seconds = t_single * self .time_scaling_factor
553+ (self .time_lines [key ],) = ax .plot (
554+ [
555+ t_in_seconds / self .time_scaling_factor ,
556+ t_in_seconds / self .time_scaling_factor ,
557+ ],
558+ [y_min , y_max ],
559+ "--" , # Dashed lines
560+ lw = 1.5 ,
561+ color = time_colors [idx ], # Different color for each time
562+ label = f"t = { t_single :.2f} { self .time_unit } " ,
563+ )
564+ ax .legend ()
565+
554566 elif variable_lists [0 ][0 ].dimensions == 1 :
555- # 1D plot: plot as a function of x at time t
556- # Read dictionary of spatial variables
567+ # 1D plot: plot as a function of x at different times
557568 spatial_vars = self .spatial_variable_dict [key ]
558569 spatial_var_name = next (iter (spatial_vars .keys ()))
559- ax .set_xlabel (
560- f"{ spatial_var_name } [{ self .spatial_unit } ]" ,
561- )
562- for i , variable_list in enumerate (variable_lists ):
563- for j , variable in enumerate (variable_list ):
564- if len (variable_list ) == 1 :
565- # single variable -> use linestyle to differentiate model
566- linestyle = self .linestyles [i ]
567- else :
568- # multiple variables -> use linestyle to differentiate
569- # variables (color differentiates models)
570- linestyle = self .linestyles [j ]
571- (self .plots [key ][i ][j ],) = ax .plot (
572- self .first_spatial_variable [key ],
573- variable (t_in_seconds , ** spatial_vars ),
574- color = self .colors [i ],
575- linestyle = linestyle ,
576- zorder = 10 ,
577- )
578- variable_handles .append (self .plots [key ][0 ][j ])
579- solution_handles .append (self .plots [key ][i ][0 ])
580- # add lines for boundaries between subdomains
581- for boundary in variable_lists [0 ][0 ].internal_boundaries :
582- boundary_scaled = boundary * self .spatial_factor
583- ax .axvline (boundary_scaled , color = "0.5" , lw = 1 , zorder = 0 )
570+ ax .set_xlabel (f"{ spatial_var_name } [{ self .spatial_unit } ]" )
571+
572+ for idx , t_single in enumerate (t ):
573+ t_in_seconds = t_single * self .time_scaling_factor
574+
575+ for i , variable_list in enumerate (variable_lists ):
576+ for j , variable in enumerate (variable_list ):
577+ linestyle = (
578+ self .linestyles [i ]
579+ if len (variable_list ) == 1
580+ else self .linestyles [j ]
581+ )
582+ (self .plots [key ][i ][j ],) = ax .plot (
583+ self .first_spatial_variable [key ],
584+ variable (t_in_seconds , ** spatial_vars ),
585+ color = time_colors [idx ], # Different color for each time
586+ linestyle = linestyle ,
587+ label = f"t = { t_single :.2f} { self .time_unit } " , # Add time label
588+ zorder = 10 ,
589+ )
590+ variable_handles .append (self .plots [key ][0 ][j ])
591+ solution_handles .append (self .plots [key ][i ][0 ])
592+
593+ # Add a legend to indicate which plot corresponds to which time
594+ ax .legend ()
595+
584596 elif variable_lists [0 ][0 ].dimensions == 2 :
585- # Read dictionary of spatial variables
597+ # 2D plot: superimpose plots at different times
586598 spatial_vars = self .spatial_variable_dict [key ]
587- # there can only be one entry in the variable list
588599 variable = variable_lists [0 ][0 ]
589- # different order based on whether the domains are x-r, x-z or y-z, etc
590- if self .x_first_and_y_second [key ] is False :
591- x_name = list (spatial_vars .keys ())[1 ][0 ]
592- y_name = next (iter (spatial_vars .keys ()))[0 ]
593- x = self .second_spatial_variable [key ]
594- y = self .first_spatial_variable [key ]
595- var = variable (t_in_seconds , ** spatial_vars )
596- else :
597- x_name = next (iter (spatial_vars .keys ()))[0 ]
598- y_name = list (spatial_vars .keys ())[1 ][0 ]
600+
601+ for t_single in t :
602+ t_in_seconds = t_single * self .time_scaling_factor
599603 x = self .first_spatial_variable [key ]
600604 y = self .second_spatial_variable [key ]
601605 var = variable (t_in_seconds , ** spatial_vars ).T
602- ax .set_xlabel (f"{ x_name } [{ self .spatial_unit } ]" )
603- ax .set_ylabel (f"{ y_name } [{ self .spatial_unit } ]" )
604- vmin , vmax = self .variable_limits [key ]
605- # store the plot and the var data (for testing) as cant access
606- # z data from QuadMesh or QuadContourSet object
607- if self .is_y_z [key ] is True :
608- self .plots [key ][0 ][0 ] = ax .pcolormesh (
609- x ,
610- y ,
611- var ,
612- vmin = vmin ,
613- vmax = vmax ,
614- shading = self .shading ,
606+
607+ ax .set_xlabel (
608+ f"{ next (iter (spatial_vars .keys ()))[0 ]} [{ self .spatial_unit } ]"
615609 )
616- else :
617- self .plots [key ][0 ][0 ] = ax .contourf (
618- x , y , var , levels = 100 , vmin = vmin , vmax = vmax
610+ ax .set_ylabel (
611+ f"{ list (spatial_vars .keys ())[1 ][0 ]} [{ self .spatial_unit } ]"
619612 )
620- self .plots [key ][0 ][1 ] = var
621- if vmin is None and vmax is None :
622- vmin = ax_min (var )
623- vmax = ax_max (var )
624- self .colorbars [key ] = self .fig .colorbar (
625- cm .ScalarMappable (colors .Normalize (vmin = vmin , vmax = vmax )),
626- ax = ax ,
627- )
628- # Set either y label or legend entries
629- if len (key ) == 1 :
630- title = split_long_string (key [0 ])
631- ax .set_title (title , fontsize = "medium" )
632- else :
633- ax .legend (
634- variable_handles ,
635- [split_long_string (s , 6 ) for s in key ],
636- bbox_to_anchor = (0.5 , 1 ),
637- loc = "lower center" ,
638- )
613+ vmin , vmax = self .variable_limits [key ]
614+
615+ # Use contourf and colorbars to represent the values
616+ contour_plot = ax .contourf (
617+ x , y , var , levels = 100 , vmin = vmin , vmax = vmax , cmap = "coolwarm"
618+ )
619+ self .plots [key ][0 ][0 ] = contour_plot
620+ self .colorbars [key ] = self .fig .colorbar (contour_plot , ax = ax )
639621
640- # Set global legend
622+ self .plots [key ][0 ][1 ] = var
623+
624+ ax .set_title (f"t = { t_single :.2f} { self .time_unit } " )
625+
626+ # Set global legend if there are multiple models
641627 if len (self .labels ) > 1 :
642628 fig_legend = self .fig .legend (
643629 solution_handles , self .labels , loc = "lower right"
644630 )
645- # Get the position of the top of the legend in relative figure units
646- # There may be a better way ...
647- try :
648- legend_top_inches = fig_legend .get_window_extent (
649- renderer = self .fig .canvas .get_renderer ()
650- ).get_points ()[1 , 1 ]
651- fig_height_inches = (self .fig .get_size_inches () * self .fig .dpi )[1 ]
652- legend_top = legend_top_inches / fig_height_inches
653- except AttributeError : # pragma: no cover
654- # When testing the examples we set the matplotlib backend to "Template"
655- # which means that the above code doesn't work. Since this is just for
656- # that particular test we can just skip it
657- legend_top = 0
658631 else :
659- legend_top = 0
632+ fig_legend = None
660633
661- # Fix layout
634+ # Fix layout for sliders if dynamic
662635 if dynamic :
663636 slider_top = 0.05
664637 else :
665638 slider_top = 0
666- bottom = max (legend_top , slider_top )
639+ bottom = max (
640+ fig_legend .get_window_extent (
641+ renderer = self .fig .canvas .get_renderer ()
642+ ).get_points ()[1 , 1 ]
643+ if fig_legend
644+ else 0 ,
645+ slider_top ,
646+ )
667647 self .gridspec .tight_layout (self .fig , rect = [0 , bottom , 1 , 1 ])
668648
669649 def dynamic_plot (self , show_plot = True , step = None ):
0 commit comments