3030
3131MODELS = [
3232 "arima" ,
33- "automlx" ,
33+ # "automlx",
3434 "prophet" ,
3535 "neuralprophet" ,
3636]
@@ -140,9 +140,10 @@ def setup_test_data(model, freq, num_series, horizon=5, num_points=100, seed=42,
140140@pytest .mark .parametrize ("model" , MODELS )
141141@pytest .mark .parametrize ("freq" , ["D" , "W" , "M" , "H" , "T" ])
142142@pytest .mark .parametrize ("num_series" , [1 , 3 ])
143- def test_explanations_output (model , freq , num_series ):
143+ def test_explanations_output_and_columns (model , freq , num_series ):
144144 """
145145 Test the global and local explanations for different models, frequencies, and number of series.
146+ Also test that the explanation output contains all the columns from the additional dataset.
146147
147148 Parameters:
148149 - model: The forecasting model to use.
@@ -156,7 +157,7 @@ def test_explanations_output(model, freq, num_series):
156157 if model == "neuralprophet" :
157158 pytest .skip ("Skipping 'neuralprophet' model as it takes a long time to finish" )
158159
159- _ , _ , operator_config = setup_test_data (model , freq , num_series )
160+ _ , additional , operator_config = setup_test_data (model , freq , num_series )
160161
161162 results = forecast_operate (operator_config )
162163
@@ -176,33 +177,6 @@ def test_explanations_output(model, freq, num_series):
176177 not (local_explanations == 0 ).all ().all ()
177178 ), "Local explanations contain only 0 values"
178179
179-
180- @pytest .mark .parametrize ("model" , MODELS )
181- @pytest .mark .parametrize ("freq" , ["D" , "W" , "M" , "H" , "T" ])
182- @pytest .mark .parametrize ("num_series" , [1 , 3 ])
183- def test_explanations_columns (model , freq , num_series ):
184- """
185- Test that the explanation output contains all the columns from the additional dataset.
186-
187- Parameters:
188- - model: The forecasting model to use.
189- - freq: Frequency of the datetime column.
190- - num_series: Number of different time series to generate.
191- """
192- if model == "automlx" and freq == "T" :
193- pytest .skip (
194- "Skipping 'T' frequency for 'automlx' model. automlx requires data with a frequency of at least one hour"
195- )
196- if model == "neuralprophet" :
197- pytest .skip ("Skipping 'neuralprophet' model as it takes a long time to finish" )
198-
199- _ , additional , operator_config = setup_test_data (model , freq , num_series )
200-
201- results = forecast_operate (operator_config )
202-
203- global_explanations = results .get_global_explanations ()
204- local_explanations = results .get_local_explanations ()
205-
206180 additional_columns = additional .columns .tolist ()
207181 for column in additional_columns :
208182 assert column in global_explanations .columns , f"Column { column } missing in global explanations"
0 commit comments