@@ -96,7 +96,9 @@ def test_generate_datasets():
9696 assert "target" not in additional_columns
9797
9898
99- def setup_test_data (model , freq , num_series , horizon = 5 , num_points = 100 , seed = 42 , include_additional = True ):
99+ def setup_test_data (
100+ model , freq , num_series , horizon = 5 , num_points = 100 , seed = 42 , include_additional = True
101+ ):
100102 """
101103 Setup test data for the given parameters.
102104
@@ -113,17 +115,21 @@ def setup_test_data(model, freq, num_series, horizon=5, num_points=100, seed=42,
113115 - Tuple containing primary, additional datasets and the operator configuration.
114116 """
115117 primary , additional , _ , _ = generate_datasets (
116- freq = freq , horizon = horizon , num_series = num_series , num_points = num_points , seed = seed
118+ freq = freq ,
119+ horizon = horizon ,
120+ num_series = num_series ,
121+ num_points = num_points ,
122+ seed = seed ,
117123 )
118124
119125 yaml_i = deepcopy (TEMPLATE_YAML )
120126 yaml_i ["spec" ]["historical_data" ].pop ("url" )
121127 yaml_i ["spec" ]["historical_data" ]["data" ] = primary
122128 yaml_i ["spec" ]["historical_data" ]["format" ] = "pandas"
123-
129+
124130 if include_additional :
125131 yaml_i ["spec" ]["additional_data" ] = {"data" : additional , "format" : "pandas" }
126-
132+
127133 yaml_i ["spec" ]["model" ] = model
128134 yaml_i ["spec" ]["target_column" ] = "target"
129135 yaml_i ["spec" ]["datetime_column" ]["name" ] = "ds"
@@ -177,10 +183,18 @@ def test_explanations_output_and_columns(model, freq, num_series):
177183 not (local_explanations == 0 ).all ().all ()
178184 ), "Local explanations contain only 0 values"
179185
180- additional_columns = additional .columns .tolist ()
186+ additional_columns = list (
187+ set (additional .columns .tolist ())
188+ - set (operator_config .spec .target_category_columns )
189+ - {operator_config .spec .datetime_column .name }
190+ )
181191 for column in additional_columns :
182- assert column in global_explanations .columns , f"Column { column } missing in global explanations"
183- assert column in local_explanations .columns , f"Column { column } missing in local explanations"
192+ assert (
193+ column in global_explanations .T .columns
194+ ), f"Column { column } missing in global explanations"
195+ assert (
196+ column in local_explanations .columns
197+ ), f"Column { column } missing in local explanations"
184198
185199
186200@pytest .mark .parametrize ("model" , MODELS )
@@ -208,11 +222,19 @@ def test_explanations_filenames(model, num_series):
208222
209223 results = forecast_operate (operator_config )
210224
211- global_explanation_path = os .path .join (output_directory , global_explanation_filename )
212- local_explanation_path = os .path .join (output_directory , local_explanation_filename )
225+ global_explanation_path = os .path .join (
226+ output_directory , global_explanation_filename
227+ )
228+ local_explanation_path = os .path .join (
229+ output_directory , local_explanation_filename
230+ )
213231
214- assert os .path .exists (global_explanation_path ), f"Global explanation file not found at { global_explanation_path } "
215- assert os .path .exists (local_explanation_path ), f"Local explanation file not found at { local_explanation_path } "
232+ assert os .path .exists (
233+ global_explanation_path
234+ ), f"Global explanation file not found at { global_explanation_path } "
235+ assert os .path .exists (
236+ local_explanation_path
237+ ), f"Local explanation file not found at { local_explanation_path } "
216238
217239
218240@pytest .mark .parametrize ("model" , MODELS )
@@ -231,19 +253,23 @@ def test_explanations_no_additional_data(model, num_series, caplog):
231253 with tempfile .TemporaryDirectory () as tmpdirname :
232254 output_directory = tmpdirname
233255
234- _ , _ , operator_config = setup_test_data (model , "D" , num_series , include_additional = False )
256+ _ , _ , operator_config = setup_test_data (
257+ model , "D" , num_series , include_additional = False
258+ )
235259 operator_config .spec .output_directory .url = output_directory
236260
237261 forecast_operate (operator_config )
238262
239263 assert any (
240264 "Unable to generate explanations as there is no additional data passed in. Either set generate_explanations to False, or pass in additional data."
241- in message for message in caplog .messages
265+ in message
266+ for message in caplog .messages
242267 ), "Required warning message not found in logs"
243268
244269
245270MODES = ["BALANCED" , "HIGH_ACCURACY" ]
246271
272+
247273@pytest .mark .skip (reason = "Disabled by default. Enable to run this test." )
248274@pytest .mark .parametrize ("mode" , MODES )
249275@pytest .mark .parametrize ("model" , MODELS )
@@ -269,11 +295,19 @@ def test_explanations_accuracy_mode(mode, model, num_series):
269295
270296 results = forecast_operate (operator_config )
271297
272- global_explanation_path = os .path .join (output_directory , operator_config .spec .global_explanation_filename )
273- local_explanation_path = os .path .join (output_directory , operator_config .spec .local_explanation_filename )
298+ global_explanation_path = os .path .join (
299+ output_directory , operator_config .spec .global_explanation_filename
300+ )
301+ local_explanation_path = os .path .join (
302+ output_directory , operator_config .spec .local_explanation_filename
303+ )
274304
275- assert os .path .exists (global_explanation_path ), f"Global explanation file not found at { global_explanation_path } "
276- assert os .path .exists (local_explanation_path ), f"Local explanation file not found at { local_explanation_path } "
305+ assert os .path .exists (
306+ global_explanation_path
307+ ), f"Global explanation file not found at { global_explanation_path } "
308+ assert os .path .exists (
309+ local_explanation_path
310+ ), f"Local explanation file not found at { local_explanation_path } "
277311
278312
279313@pytest .mark .parametrize ("model" , MODELS )
0 commit comments