diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 92a15d985..7d8c43302 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -17,7 +17,6 @@ import torch import numpy as np import threading -import time from torch.nn import functional as F from contextlib import nullcontext import uuid @@ -57,6 +56,7 @@ def __init__(self, # dict used to store session related variable self.tts_speech_token_dict = {} self.llm_end_dict = {} + self.token_condition_dict = {} self.mel_overlap_dict = {} self.flow_cache_dict = {} self.hift_cache_dict = {} @@ -125,12 +125,18 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uui continue else: cur_silent_token_num = 0 - self.tts_speech_token_dict[uuid].append(i) - self.llm_end_dict[uuid] = True + with self.lock: + self.tts_speech_token_dict[uuid].append(i) + self.token_condition_dict[uuid].notify() + with self.lock: + self.llm_end_dict[uuid] = True + self.token_condition_dict[uuid].notify() def vc_job(self, source_speech_token, uuid): - self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist() - self.llm_end_dict[uuid] = True + with self.lock: + self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist() + self.llm_end_dict[uuid] = True + self.token_condition_dict[uuid].notify() def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): with torch.cuda.amp.autocast(self.fp16): @@ -181,6 +187,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze this_uuid = str(uuid.uuid1()) with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False + self.token_condition_dict[this_uuid] = threading.Condition(self.lock) self.hift_cache_dict[this_uuid] = None self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0) self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2) @@ -192,10 +199,18 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze if stream is True: token_hop_len = self.token_min_hop_len while True: - time.sleep(0.1) - if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \ - .unsqueeze(dim=0) + with self.lock: + while len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len and \ + self.llm_end_dict[this_uuid] is False: + self.token_condition_dict[this_uuid].wait() + if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len: + this_tts_speech_token_slice = self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len] + elif self.llm_end_dict[this_uuid] is True: + break + else: + continue + this_tts_speech_token = torch.tensor(this_tts_speech_token_slice).unsqueeze(dim=0) + if this_tts_speech_token.shape[1] != 0: this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, @@ -207,8 +222,6 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:] # increase token_hop_len for better speech quality token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor)) - if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len: - break p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) @@ -234,6 +247,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze with self.lock: self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) + self.token_condition_dict.pop(this_uuid) self.mel_overlap_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) self.flow_cache_dict.pop(this_uuid) @@ -271,6 +285,7 @@ def __init__(self, # dict used to store session related variable self.tts_speech_token_dict = {} self.llm_end_dict = {} + self.token_condition_dict = {} self.hift_cache_dict = {} self.silent_tokens = [] @@ -287,6 +302,10 @@ def load_vllm(self, model_dir): gpu_memory_utilization=0.2) self.llm.vllm = LLMEngine.from_engine_args(engine_args) self.llm.lock = threading.Lock() + self.llm.vllm_step_condition = threading.Condition(self.llm.lock) + self.llm.vllm_step_thread = None + self.llm.vllm_background_error = None + self.llm._ensure_vllm_runtime() del self.llm.llm.model.model.layers def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0): @@ -334,6 +353,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze this_uuid = str(uuid.uuid1()) with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False + self.token_condition_dict[this_uuid] = threading.Condition(self.lock) self.hift_cache_dict[this_uuid] = None if source_speech_token.shape[1] == 0: p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid)) @@ -344,10 +364,19 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze token_offset = 0 prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1]) while True: - time.sleep(0.1) this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len - if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len: - this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) + required_token_len = token_offset + this_token_hop_len + self.flow.pre_lookahead_len + with self.lock: + while len(self.tts_speech_token_dict[this_uuid]) < required_token_len and self.llm_end_dict[this_uuid] is False: + self.token_condition_dict[this_uuid].wait() + if len(self.tts_speech_token_dict[this_uuid]) >= required_token_len: + this_tts_speech_token_slice = self.tts_speech_token_dict[this_uuid][:required_token_len] + elif self.llm_end_dict[this_uuid] is True: + break + else: + continue + this_tts_speech_token = torch.tensor(this_tts_speech_token_slice).unsqueeze(dim=0) + if this_tts_speech_token.shape[1] != 0: this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, prompt_feat=prompt_speech_feat, @@ -359,8 +388,6 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze token_offset += this_token_hop_len self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor) yield {'tts_speech': this_tts_speech.cpu()} - if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len: - break p.join() # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) @@ -370,6 +397,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze embedding=flow_embedding, token_offset=token_offset, uuid=this_uuid, + stream=stream, finalize=True) yield {'tts_speech': this_tts_speech.cpu()} else: @@ -388,9 +416,10 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze with self.lock: self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) + self.token_condition_dict.pop(this_uuid) self.hift_cache_dict.pop(this_uuid) if torch.cuda.is_available(): - torch.cuda.empty_cache() + # torch.cuda.empty_cache() torch.cuda.current_stream().synchronize() @@ -418,6 +447,7 @@ def __init__(self, # dict used to store session related variable self.tts_speech_token_dict = {} self.llm_end_dict = {} + self.token_condition_dict = {} self.hift_cache_dict = {} # FSQ silent and breath token self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323] diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index d3beb9ec2..681e9c05c 100644 --- a/cosyvoice/flow/flow_matching.py +++ b/cosyvoice/flow/flow_matching.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import nullcontext + import torch import torch.nn.functional as F from matcha.models.components.flow_matching import BASECFM @@ -128,9 +130,11 @@ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False): return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming) else: [estimator, stream], trt_engine = self.estimator.acquire_estimator() - # NOTE need to synchronize when switching stream - torch.cuda.current_stream().synchronize() - with stream: + stream_context = stream if stream is not None else nullcontext() + if stream is not None: + # NOTE only synchronize when switching to a dedicated TRT stream. + torch.cuda.current_stream().synchronize() + with stream_context: estimator.set_input_shape('x', (2, 80, x.size(2))) estimator.set_input_shape('mask', (2, 1, x.size(2))) estimator.set_input_shape('mu', (2, 80, x.size(2))) @@ -148,7 +152,8 @@ def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False): estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) # run trt engine assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True - torch.cuda.current_stream().synchronize() + if stream is not None: + torch.cuda.current_stream().synchronize() self.estimator.release_estimator(estimator, stream) return x diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py index e8e81d942..b40399a80 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os, queue +import os +import hashlib +import queue import random -import time import threading from typing import Dict, Optional, Callable, List, Generator import numpy as np @@ -31,6 +32,62 @@ from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path +def _generation_env_int(name: str, default: int) -> int: + try: + return int(os.getenv(name, str(default))) + except ValueError: + return default + + +def _generation_request_seed(lm_input: torch.Tensor) -> int: + base_seed = _generation_env_int('COSYVOICE_SAMPLING_SEED', 0) + payload = lm_input.detach().float().contiguous().cpu().numpy().tobytes() + digest = hashlib.sha256(payload).digest() + return (int.from_bytes(digest[:8], byteorder='little', signed=False) ^ base_seed) & ((1 << 63) - 1) + + +def _build_generation_generator(lm_input: torch.Tensor): + seed = _generation_request_seed(lm_input) + generator = torch.Generator(device=lm_input.device if lm_input.is_cuda else 'cpu') + generator.manual_seed(seed) + return generator, seed + + +class CosyVoiceSamplingLogitsProcessor: + """Force vLLM to follow CosyVoice's sampling_ids / ras_sampling path.""" + + def __init__(self, sampling_fn: Callable, sampling: int, speech_token_size: int, min_tokens: int, + seed: Optional[int] = None): + self.sampling_fn = sampling_fn + self.sampling = sampling + self.speech_token_size = speech_token_size + self.min_tokens = min_tokens + self.seed = seed + self.generator = None + + def __call__(self, token_ids: list[int], logits: torch.Tensor) -> torch.Tensor: + if self.generator is None: + self.generator = torch.Generator(device=logits.device) + if self.seed is not None: + self.generator.manual_seed(self.seed) + num_trials, max_trials = 0, 100 + ignore_eos = len(token_ids) < self.min_tokens + while True: + top_ids = self.sampling_fn(logits, token_ids, self.sampling, generator=self.generator) + if (not ignore_eos) or (top_ids < self.speech_token_size): + break + num_trials += 1 + if num_trials > max_trials: + raise RuntimeError( + 'sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format( + max_trials + ) + ) + forced_logits = torch.full_like(logits, torch.finfo(logits.dtype).min) + forced_logits[top_ids] = 0 + return forced_logits + + class TransformerLM(torch.nn.Module): def __init__( self, @@ -153,10 +210,16 @@ def sampling_ids( decoded_tokens: List, sampling: int, ignore_eos: bool = True, + generator: Optional[torch.Generator] = None, ): - if ignore_eos is True: - weighted_scores[self.speech_token_size] = -float('inf') - top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) + num_trials, max_trials = 0, 100 + while True: + top_ids = self.sampling(weighted_scores, decoded_tokens, sampling, generator=generator) + if (not ignore_eos) or (top_ids < self.speech_token_size): + break + num_trials += 1 + if num_trials > max_trials: + raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials)) return top_ids @torch.inference_mode() @@ -296,9 +359,50 @@ def __init__( # 5. vllm related self.stop_token_ids = [speech_token_size + i for i in range(3)] self.vllm_output_queue = {} + self.vllm_finished_requests = set() + self.vllm_step_condition = None + self.vllm_step_thread = None + self.vllm_background_error = None if online_feature is True: self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx')) + def _ensure_vllm_runtime(self): + if self.vllm_background_error is not None: + raise self.vllm_background_error + if getattr(self, 'lock', None) is None: + self.lock = threading.Lock() + if self.vllm_step_condition is None: + self.vllm_step_condition = threading.Condition(self.lock) + if self.vllm_step_thread is None or self.vllm_step_thread.is_alive() is False: + self.vllm_step_thread = threading.Thread(target=self._vllm_step_loop, daemon=True) + self.vllm_step_thread.start() + + def _vllm_step_loop(self): + while True: + try: + with self.lock: + while len(self.vllm_output_queue) == 0 or all(i in self.vllm_finished_requests for i in self.vllm_output_queue): + self.vllm_step_condition.wait() + request_outputs = self.vllm.step() + for request_output in request_outputs: + if len(request_output.outputs) == 0: + continue + token_ids = request_output.outputs[0].token_ids + if len(token_ids) == 0: + continue + output_queue = self.vllm_output_queue.get(request_output.request_id, None) + if output_queue is not None: + top_id = token_ids[-1] + if top_id in self.stop_token_ids: + self.vllm_finished_requests.add(request_output.request_id) + output_queue.put(top_id) + except Exception as e: + with self.lock: + self.vllm_background_error = RuntimeError('vLLM step worker failed: {}'.format(e)) + for output_queue in self.vllm_output_queue.values(): + output_queue.put(self.vllm_background_error) + return + def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None): lm_target, lm_input = [], [] text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True) @@ -503,25 +607,41 @@ def inference( @torch.inference_mode() def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid): + generator, seed = _build_generation_generator(lm_input) if hasattr(self, 'vllm'): from vllm import SamplingParams, RequestOutput + self._ensure_vllm_runtime() + logits_processor = CosyVoiceSamplingLogitsProcessor( + sampling_fn=self.sampling, + sampling=sampling, + speech_token_size=self.speech_token_size, + min_tokens=min_len, + seed=seed, + ) sampling_params = SamplingParams(top_k=sampling, stop_token_ids=self.stop_token_ids, - min_tokens=min_len, - max_tokens=max_len) + min_tokens=0, + max_tokens=max_len, + temperature=0.0, + top_p=1.0, + logits_processors=[logits_processor]) + output_queue = queue.Queue() with self.lock: - self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params) - self.vllm_output_queue[uuid] = queue.Queue() + self.vllm_output_queue[uuid] = output_queue + self.vllm_finished_requests.discard(uuid) + try: + self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params) + except Exception: + self.vllm_output_queue.pop(uuid, None) + self.vllm_finished_requests.discard(uuid) + raise + self.vllm_step_condition.notify() out_tokens = [] - while True: - with self.lock: - if self.vllm_output_queue[uuid].empty() is True: - request_outputs: List[RequestOutput] = self.vllm.step() - for request_output in request_outputs: - top_ids = list(request_output.outputs[0].token_ids)[-1] - self.vllm_output_queue[request_output.request_id].put(top_ids) - if self.vllm_output_queue[uuid].empty() is False: - top_ids = self.vllm_output_queue[uuid].get() + try: + while True: + top_ids = output_queue.get() + if isinstance(top_ids, Exception): + raise top_ids if top_ids in self.stop_token_ids: break # in stream mode, yield token one by one @@ -529,9 +649,11 @@ def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid): out_tokens.append(top_ids) if len(out_tokens) == max_len: break - time.sleep(0.001) - with self.lock: - self.vllm_output_queue.pop(uuid) + finally: + with self.lock: + self.vllm_finished_requests.add(uuid) + self.vllm_output_queue.pop(uuid, None) + self.vllm_finished_requests.discard(uuid) else: out_tokens = [] cache = None @@ -540,7 +662,8 @@ def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid): masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), cache=cache) logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) - top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False) + ignore_eos = True if i < min_len else False + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=ignore_eos, generator=generator) if top_ids in self.stop_token_ids: break # in stream mode, yield token one by one @@ -702,5 +825,9 @@ def __init__( # 5. vllm related self.stop_token_ids = [speech_token_size + i for i in range(200)] self.vllm_output_queue = {} + self.vllm_finished_requests = set() + self.vllm_step_condition = None + self.vllm_step_thread = None + self.vllm_background_error = None if online_feature is True: - self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx')) \ No newline at end of file + self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx')) diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py index 3f235a62e..f24686f3f 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -135,16 +135,16 @@ def init_weights(m, mean=0.0, std=0.01): # Repetition Aware Sampling in VALL-E 2 -def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): - top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) +def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1, generator=None): + top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k, generator=generator) rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() if rep_num >= win_size * tau_r: weighted_scores[top_ids] = -float('inf') - top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) + top_ids = random_sampling(weighted_scores, decoded_tokens, sampling, generator=generator) return top_ids -def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): +def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25, generator=None): prob, indices = [], [] cum_prob = 0.0 sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) @@ -158,24 +158,22 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): break prob = torch.tensor(prob).to(weighted_scores) indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) - top_ids = indices[prob.multinomial(1, replacement=True)].item() + top_ids = indices[prob.multinomial(1, replacement=True, generator=generator)].item() return top_ids -def random_sampling(weighted_scores, decoded_tokens, sampling): - top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item() +def random_sampling(weighted_scores, decoded_tokens, sampling, generator=None): + top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True, generator=generator).item() return top_ids def fade_in_out(fade_in_mel, fade_out_mel, window): - device = fade_in_mel.device - fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu() mel_overlap_len = int(window.shape[0] / 2) - if fade_in_mel.device == torch.device('cpu'): - fade_in_mel = fade_in_mel.clone() + fade_in_mel = fade_in_mel.clone() + window = torch.as_tensor(window, device=fade_in_mel.device, dtype=fade_in_mel.dtype) fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \ fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] - return fade_in_mel.to(device) + return fade_in_mel def set_all_random_seed(seed): @@ -202,7 +200,7 @@ def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): self.trt_engine = trt_engine for _ in range(trt_concurrent): trt_context = trt_engine.create_execution_context() - trt_stream = torch.cuda.stream(torch.cuda.Stream(device)) + trt_stream = None if trt_concurrent == 1 else torch.cuda.stream(torch.cuda.Stream(device)) assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) self.trt_context_pool.put([trt_context, trt_stream]) assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'