import sklearn
import sklearn.datasets
+from sklearn.base import clone
from sklearn.ensemble import VotingClassifier, VotingRegressor

+from smac.runhistory.runhistory import RunHistory
+
import torch

from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.api.tabular_regression import TabularRegressionTask
from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
)
from autoPyTorch.optimizer.smbo import AutoMLSMBO
+from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy


# Fixtures
@@ -104,17 +108,20 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):

    # Search for an existing run key in disc. A individual model might have
    # a timeout and hence was not written to disc
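+    # Track a successful run's num_run and whether its directory reached disc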
+    successful_num_run = None
+    SUCCESS = False
    for i, (run_key, value) in enumerate(estimator.run_history.data.items()):
-        if 'SUCCESS' not in str(value.status):
-            continue
-
-        run_key_model_run_dir = estimator._backend.get_numrun_directory(
-            estimator.seed, run_key.config_id + 1, run_key.budget)
-        if os.path.exists(run_key_model_run_dir):
-            # Runkey config id is different from the num_run
-            # more specifically num_run = config_id + 1(dummy)
+        if 'SUCCESS' in str(value.status):
+            run_key_model_run_dir = estimator._backend.get_numrun_directory(
+                estimator.seed, run_key.config_id + 1, run_key.budget)
            successful_num_run = run_key.config_id + 1
-            break
+            if os.path.exists(run_key_model_run_dir):
+                # Runkey config id is different from the num_run
+                # more specifically num_run = config_id + 1(dummy)
+                SUCCESS = True
+                break
+
+    assert SUCCESS, f"Successful run was not properly saved for num_run: {successful_num_run}"

    if resampling_strategy == HoldoutValTypes.holdout_validation:
        model_file = os.path.join(run_key_model_run_dir,
@@ -272,17 +279,20 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):

    # Search for an existing run key in disc. A individual model might have
    # a timeout and hence was not written to disc
+    successful_num_run = None
+    SUCCESS = False
    for i, (run_key, value) in enumerate(estimator.run_history.data.items()):
-        if 'SUCCESS' not in str(value.status):
-            continue
-
-        run_key_model_run_dir = estimator._backend.get_numrun_directory(
-            estimator.seed, run_key.config_id + 1, run_key.budget)
-        if os.path.exists(run_key_model_run_dir):
-            # Runkey config id is different from the num_run
-            # more specifically num_run = config_id + 1(dummy)
+        if 'SUCCESS' in str(value.status):
+            run_key_model_run_dir = estimator._backend.get_numrun_directory(
+                estimator.seed, run_key.config_id + 1, run_key.budget)
            successful_num_run = run_key.config_id + 1
-            break
+            if os.path.exists(run_key_model_run_dir):
+                # Runkey config id is different from the num_run
+                # more specifically num_run = config_id + 1(dummy)
+                SUCCESS = True
+                break
+
+    assert SUCCESS, f"Successful run was not properly saved for num_run: {successful_num_run}"

    if resampling_strategy == HoldoutValTypes.holdout_validation:
        model_file = os.path.join(run_key_model_run_dir,
@@ -384,7 +394,7 @@ def test_tabular_input_support(openml_id, backend):
    estimator._do_dummy_prediction = unittest.mock.MagicMock()

    with unittest.mock.patch.object(AutoMLSMBO, 'run_smbo') as AutoMLSMBOMock:
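+        # run_smbo is expected to return a RunHistory (not a bare dict) as the
+        # first element of its result tuple, so the mock must supply one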
-        AutoMLSMBOMock.return_value = ({}, {}, 'epochs')
+        AutoMLSMBOMock.return_value = (RunHistory(), {}, 'epochs')
        estimator.search(
            X_train=X_train, y_train=y_train,
            X_test=X_test, y_test=y_test,
@@ -394,3 +404,48 @@ def test_tabular_input_support(openml_id, backend):
            enable_traditional_pipeline=False,
            load_models=False,
        )
+
+
+@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
+def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
+    backend = fit_dictionary_tabular['backend']
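+    # ensemble_size=0 disables ensemble building, so only the dummy
+    # prediction machinery is exercised here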
+    estimator = TabularClassificationTask(
+        backend=backend,
+        resampling_strategy=HoldoutValTypes.holdout_validation,
+        ensemble_size=0,
+    )
+
+    # Setup pre-requisites normally set by search()
+    estimator._create_dask_client()
+    estimator._metric = accuracy
+    estimator._logger = estimator._get_logger('test')
+    estimator._memory_limit = 5000
+    estimator._time_for_task = 60
+    estimator._disable_file_output = []
+    estimator._all_supported_metrics = False
+
+    estimator._do_dummy_prediction()
+
+    # Ensure that the dummy predictions are not in the current working
+    # directory, but in the temporary directory.
+    assert not os.path.exists(os.path.join(os.getcwd(), '.autoPyTorch'))
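+    # Run artefacts live under runs/<seed>_<num_run>_<budget> (cf. the
+    # get_numrun_directory calls above); the dummy model is num_run 1,
+    # hence '1_1_1.0'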
+    assert os.path.exists(os.path.join(
+        backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_1.0',
+        'predictions_ensemble_1_1_1.0.npy')
+    )
+
+    model_path = os.path.join(backend.temporary_directory,
+                              '.autoPyTorch',
+                              'runs', '1_1_1.0',
+                              '1.1.1.0.model')
+
+    # Make sure the dummy model complies with scikit learn
+    # get/set params
+    assert os.path.exists(model_path)
+    with open(model_path, 'rb') as model_handler:
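+        # sklearn.base.clone rebuilds the estimator via get_params/set_params
+        # and raises if the pickled object is not a valid sklearn estimator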
+        clone(pickle.load(model_handler))
+
+    estimator._close_dask_client()
+    estimator._clean_logger()
+
+    del estimator