Skip to content

Commit b7726a8

Browse files
committed
[fix] Fix mypy issues
1 parent 9481437 commit b7726a8

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

autoPyTorch/api/base_task.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def __init__(
174174
self._logger: Optional[PicklableClientLogger] = None
175175
self.run_history: RunHistory = RunHistory()
176176
self.trajectory: Optional[List] = None
177-
self.dataset_name: Optional[str] = None
177+
self.dataset_name: str = ""
178178
self.cv_models_: Dict = {}
179179
self.experiment_task_name: str = 'runSearch'
180180

@@ -702,16 +702,18 @@ def _run_traditional_ml(self) -> None:
702702
self._stopwatch.start_task(traditional_task_name)
703703
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
704704

705+
assert self._func_eval_time_limit_secs is not None
705706
time_for_traditional = int(
706707
self._total_walltime_limit - elapsed_time - self._func_eval_time_limit_secs
707708
)
708709
self._do_traditional_prediction(time_left=time_for_traditional)
709710
self._stopwatch.stop_task(traditional_task_name)
710711

711712
def _run_ensemble(self, dataset: BaseDataset, optimize_metric: str,
712-
precision: int) -> EnsembleBuilderManager:
713+
precision: int) -> Optional[EnsembleBuilderManager]:
713714

714715
assert self._logger is not None
716+
assert self._metric is not None
715717

716718
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
717719
time_left_for_ensembles = max(0, self._total_walltime_limit - elapsed_time)
@@ -788,7 +790,7 @@ def _start_smac(self, proc_smac: AutoMLSMBO) -> None:
788790
except Exception as e:
789791
self._logger.warning(f"Could not save {trajectory_filename} due to {e}...")
790792

791-
def _run_smac(self, dataset: BaseDataset, proc_ensemble: EnsembleBuilderManager,
793+
def _run_smac(self, dataset: BaseDataset, proc_ensemble: Optional[EnsembleBuilderManager],
792794
budget_type: Optional[str] = None, budget: Optional[float] = None,
793795
get_smac_object_callback: Optional[Callable] = None,
794796
smac_scenario_args: Optional[Dict[str, Any]] = None) -> None:
@@ -805,6 +807,9 @@ def _run_smac(self, dataset: BaseDataset, proc_ensemble: EnsembleBuilderManager,
805807
self._logger.warning(" Could not run SMAC because there is no time left")
806808
else:
807809
budget_config = self._get_budget_config(budget_type=budget_type, budget=budget)
810+
811+
assert self._func_eval_time_limit_secs is not None
812+
assert self._metric is not None
808813
proc_smac = AutoMLSMBO(
809814
config_space=self.search_space,
810815
dataset_name=dataset.dataset_name,
@@ -1095,7 +1100,7 @@ def refit(
10951100
Returns:
10961101
self
10971102
"""
1098-
if self.dataset_name is None:
1103+
if self.dataset_name == "":
10991104
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
11001105

11011106
if self._logger is None:
@@ -1165,7 +1170,7 @@ def fit(self,
11651170
Returns:
11661171
(BasePipeline): fitted pipeline
11671172
"""
1168-
if self.dataset_name is None:
1173+
if self.dataset_name == "":
11691174
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
11701175

11711176
if self._logger is None:

0 commit comments

Comments
 (0)