Skip to content

Commit 5fef094

Browse files
[FIX] Enables backend to track the num run (#162)
* AA_151 * doc the peek attr * [ADD] Relax constant pipeline performance
1 parent 45076ee commit 5fef094

File tree

6 files changed

+140
-29
lines changed

6 files changed

+140
-29
lines changed

autoPyTorch/api/base_task.py

+17-25
Original file line numberDiff line numberDiff line change
@@ -486,11 +486,14 @@ def _load_best_individual_model(self) -> SingleBest:
486486

487487
return ensemble
488488

489-
def _do_dummy_prediction(self, num_run: int) -> None:
489+
def _do_dummy_prediction(self) -> None:
490490

491491
assert self._metric is not None
492492
assert self._logger is not None
493493

494+
# For dummy estimator, we always expect the num_run to be 1
495+
num_run = 1
496+
494497
self._logger.info("Starting to create dummy predictions.")
495498

496499
memory_limit = self._memory_limit
@@ -551,29 +554,20 @@ def _do_dummy_prediction(self, num_run: int) -> None:
551554
% (str(status), str(additional_info))
552555
)
553556

554-
def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_time_limit_secs: int
555-
) -> int:
557+
def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: int) -> None:
556558
"""
557559
Fits traditional machine learning algorithms to the provided dataset, while
558560
complying with time resource allocation.
559561
560562
This method currently only supports classification.
561563
562564
Args:
563-
num_run: (int)
564-
An identifier to indicate the current machine learning algorithm
565-
being processed
566565
time_left: (int)
567566
Hard limit on how many machine learning algorithms can be fit. Depending on how
568567
fast a traditional machine learning algorithm trains, it will allow multiple
569568
models to be fitted.
570569
func_eval_time_limit_secs: (int)
571570
Maximum training time each algorithm is allowed to take, during training
572-
573-
Returns:
574-
num_run: (int)
575-
The incremented identifier index. This depends on how many machine learning
576-
models were fitted.
577571
"""
578572

579573
# Mypy Checkings -- Traditional prediction is only called for search
@@ -592,8 +586,8 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
592586
available_classifiers = get_available_classifiers()
593587
dask_futures = []
594588

595-
total_number_classifiers = len(available_classifiers) + num_run
596-
for n_r, classifier in enumerate(available_classifiers, start=num_run):
589+
total_number_classifiers = len(available_classifiers)
590+
for n_r, classifier in enumerate(available_classifiers):
597591

598592
# Only launch a task if there is time
599593
start_time = time.time()
@@ -612,7 +606,7 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
612606
logger_port=self._logger_port,
613607
cost_for_crash=get_cost_of_crash(self._metric),
614608
abort_on_first_run_crash=False,
615-
initial_num_run=n_r,
609+
initial_num_run=self._backend.get_next_num_run(),
616610
stats=stats,
617611
memory_limit=memory_limit,
618612
disable_file_output=True if len(self._disable_file_output) > 0 else False,
@@ -626,9 +620,6 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
626620
)
627621
])
628622

629-
# Increment the launched job index
630-
num_run = n_r
631-
632623
# When managing time, we need to take into account the allocated time resources,
633624
# which are dependent on the number of cores. 'dask_futures' is a proxy to the number
634625
# of workers /n_jobs that we have, in that if there are 4 cores allocated, we can run at most
@@ -691,7 +682,7 @@ def _do_traditional_prediction(self, num_run: int, time_left: int, func_eval_tim
691682
self.run_history.update(run_history, DataOrigin.EXTERNAL_SAME_INSTANCES)
692683
run_history.save_json(os.path.join(self._backend.internals_directory, 'traditional_run_history.json'),
693684
save_external=True)
694-
return num_run
685+
return
695686

696687
def _search(
697688
self,
@@ -861,10 +852,9 @@ def _search(
861852
)
862853

863854
# ============> Run dummy predictions
864-
num_run = 1
865855
dummy_task_name = 'runDummy'
866856
self._stopwatch.start_task(dummy_task_name)
867-
self._do_dummy_prediction(num_run)
857+
self._do_dummy_prediction()
868858
self._stopwatch.stop_task(dummy_task_name)
869859

870860
# ============> Run traditional ml
@@ -880,8 +870,8 @@ def _search(
880870
time_for_traditional = int(
881871
self._time_for_task - elapsed_time - func_eval_time_limit_secs
882872
)
883-
num_run = self._do_traditional_prediction(
884-
num_run=num_run + 1, func_eval_time_limit_secs=func_eval_time_limit_secs,
873+
self._do_traditional_prediction(
874+
func_eval_time_limit_secs=func_eval_time_limit_secs,
885875
time_left=time_for_traditional,
886876
)
887877
self._stopwatch.stop_task(traditional_task_name)
@@ -957,7 +947,9 @@ def _search(
957947
pipeline_config={**self.pipeline_options, **budget_config},
958948
ensemble_callback=proc_ensemble,
959949
logger_port=self._logger_port,
960-
start_num_run=num_run,
950+
# We do not increase the num_run here, this is something
951+
# smac does internally
952+
start_num_run=self._backend.get_next_num_run(peek=True),
961953
search_space_updates=self.search_space_updates
962954
)
963955
try:
@@ -1063,7 +1055,7 @@ def refit(
10631055
'train_indices': dataset.splits[split_id][0],
10641056
'val_indices': dataset.splits[split_id][1],
10651057
'split_id': split_id,
1066-
'num_run': 0
1058+
'num_run': self._backend.get_next_num_run(),
10671059
})
10681060
X.update({**self.pipeline_options, **budget_config})
10691061
if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None:
@@ -1140,7 +1132,7 @@ def fit(self,
11401132
'train_indices': dataset.splits[split_id][0],
11411133
'val_indices': dataset.splits[split_id][1],
11421134
'split_id': split_id,
1143-
'num_run': 0
1135+
'num_run': self._backend.get_next_num_run(),
11441136
})
11451137
X.update({**self.pipeline_options, **budget_config})
11461138

autoPyTorch/evaluation/abstract_evaluator.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ def __init__(self, config: Configuration,
144144
random_state: Optional[Union[int, np.random.RandomState]] = None,
145145
init_params: Optional[Dict] = None
146146
) -> None:
147-
self.configuration = config
147+
self.config = config
148+
self.init_params = init_params
149+
self.random_state = random_state
148150
if config == 1:
149151
super(DummyClassificationPipeline, self).__init__(strategy="uniform")
150152
else:
@@ -208,7 +210,9 @@ class DummyRegressionPipeline(DummyRegressor):
208210
def __init__(self, config: Configuration,
209211
random_state: Optional[Union[int, np.random.RandomState]] = None,
210212
init_params: Optional[Dict] = None) -> None:
211-
self.configuration = config
213+
self.config = config
214+
self.init_params = init_params
215+
self.random_state = random_state
212216
if config == 1:
213217
super(DummyRegressionPipeline, self).__init__(strategy='mean')
214218
else:

autoPyTorch/utils/backend.py

+45
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ def __init__(self, context: BackendContext):
169169
self._logger = None # type: Optional[PicklableClientLogger]
170170
self.context = context
171171

172+
# Track the number of configurations launched
173+
# num_run == 1 means a dummy estimator run
174+
self.active_num_run = 1
175+
172176
# Create the temporary directory if it does not yet exist
173177
try:
174178
os.makedirs(self.temporary_directory)
@@ -329,6 +333,47 @@ def get_runs_directory(self) -> str:
329333
def get_numrun_directory(self, seed: int, num_run: int, budget: float) -> str:
330334
return os.path.join(self.internals_directory, 'runs', '%d_%d_%s' % (seed, num_run, budget))
331335

336+
def get_next_num_run(self, peek: bool = False) -> int:
337+
"""
338+
Every pipeline that is fitted by the estimator is stored with an
339+
identifier called num_run. A dummy classifier will always have a num_run
340+
equal to 1, and all other new configurations that are explored will
341+
have a sequentially increasing identifier.
342+
343+
This method returns the next num_run a configuration should take.
344+
345+
Parameters
346+
----------
347+
peek: bool
348+
By default, the next num_rum will be returned, i.e. self.active_num_run + 1
349+
Yet, if this bool parameter is equal to True, the value of the current
350+
num_run is provided, i.e, self.active_num_run.
351+
In other words, peek allows to get the current maximum identifier
352+
of a configuration.
353+
354+
Returns
355+
-------
356+
num_run: int
357+
An unique identifier for a configuration
358+
"""
359+
360+
# If there are other num_runs, their name would be runs/<seed>_<num_run>_<budget>
361+
other_num_runs = [int(os.path.basename(run_dir).split('_')[1])
362+
for run_dir in glob.glob(os.path.join(self.internals_directory,
363+
'runs',
364+
'*'))]
365+
if len(other_num_runs) > 0:
366+
# We track the number of runs from two forefronts:
367+
# The physically available num_runs (which might be deleted or a crash could happen)
368+
# From a internally kept attribute. The later should be sufficient, but we
369+
# want to be robust against multiple backend copies on different workers
370+
self.active_num_run = max([self.active_num_run] + other_num_runs)
371+
372+
# We are interested in the next run id
373+
if not peek:
374+
self.active_num_run += 1
375+
return self.active_num_run
376+
332377
def get_model_filename(self, seed: int, idx: int, budget: float) -> str:
333378
return '%s.%s.%s.model' % (seed, idx, budget)
334379

test/test_api/test_api.py

+47
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import sklearn
1414
import sklearn.datasets
15+
from sklearn.base import clone
1516
from sklearn.ensemble import VotingClassifier, VotingRegressor
1617

1718
from smac.runhistory.runhistory import RunHistory
@@ -25,6 +26,7 @@
2526
HoldoutValTypes,
2627
)
2728
from autoPyTorch.optimizer.smbo import AutoMLSMBO
29+
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
2830

2931

3032
# Fixtures
@@ -402,3 +404,48 @@ def test_tabular_input_support(openml_id, backend):
402404
enable_traditional_pipeline=False,
403405
load_models=False,
404406
)
407+
408+
409+
@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
410+
def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
411+
backend = fit_dictionary_tabular['backend']
412+
estimator = TabularClassificationTask(
413+
backend=backend,
414+
resampling_strategy=HoldoutValTypes.holdout_validation,
415+
ensemble_size=0,
416+
)
417+
418+
# Setup pre-requisites normally set by search()
419+
estimator._create_dask_client()
420+
estimator._metric = accuracy
421+
estimator._logger = estimator._get_logger('test')
422+
estimator._memory_limit = 5000
423+
estimator._time_for_task = 60
424+
estimator._disable_file_output = []
425+
estimator._all_supported_metrics = False
426+
427+
estimator._do_dummy_prediction()
428+
429+
# Ensure that the dummy predictions are not in the current working
430+
# directory, but in the temporary directory.
431+
assert not os.path.exists(os.path.join(os.getcwd(), '.autoPyTorch'))
432+
assert os.path.exists(os.path.join(
433+
backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_1.0',
434+
'predictions_ensemble_1_1_1.0.npy')
435+
)
436+
437+
model_path = os.path.join(backend.temporary_directory,
438+
'.autoPyTorch',
439+
'runs', '1_1_1.0',
440+
'1.1.1.0.model')
441+
442+
# Make sure the dummy model complies with scikit learn
443+
# get/set params
444+
assert os.path.exists(model_path)
445+
with open(model_path, 'rb') as model_handler:
446+
clone(pickle.load(model_handler))
447+
448+
estimator._close_dask_client()
449+
estimator._clean_logger()
450+
451+
del estimator

test/test_pipeline/test_tabular_classification.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -439,5 +439,5 @@ def test_constant_pipeline_iris(fit_dictionary_tabular):
439439
val_score = run_summary.performance_tracker['val_metrics'][epoch_where_best]['balanced_accuracy']
440440
train_score = run_summary.performance_tracker['train_metrics'][epoch_where_best]['balanced_accuracy']
441441

442-
assert val_score >= 0.9, run_summary.performance_tracker['val_metrics']
443-
assert train_score >= 0.9, run_summary.performance_tracker['train_metrics']
442+
assert val_score >= 0.8, run_summary.performance_tracker['val_metrics']
443+
assert train_score >= 0.8, run_summary.performance_tracker['train_metrics']

test/test_utils/test_backend.py

+23
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# -*- encoding: utf-8 -*-
22
import builtins
3+
import logging.handlers
34
import unittest
45
import unittest.mock
56

7+
import numpy as np
8+
69
import pytest
710

811
from autoPyTorch.utils.backend import Backend
@@ -81,3 +84,23 @@ def test_loads_models_by_identifiers(exists_mock, openMock, pickleLoadMock, back
8184

8285
assert isinstance(actual_dict, dict)
8386
assert expected_dict == actual_dict
87+
88+
89+
def test_get_next_num_run(backend):
90+
# Asking for a num_run increases the tracked num_run
91+
assert backend.get_next_num_run() == 2
92+
assert backend.get_next_num_run() == 3
93+
# Then test that we are robust against new files being generated
94+
backend.setup_logger('Test', logging.handlers.DEFAULT_TCP_LOGGING_PORT)
95+
backend.save_numrun_to_dir(
96+
seed=0,
97+
idx=12,
98+
budget=0.0,
99+
model=dict(),
100+
cv_model=None,
101+
ensemble_predictions=np.zeros(10),
102+
valid_predictions=None,
103+
test_predictions=None,
104+
)
105+
assert backend.get_next_num_run() == 13
106+
assert backend.get_next_num_run(peek=True) == 13

0 commit comments

Comments
 (0)