__all__ = ['decode']

import concurrent.futures as cf
import threading

import mllp_grpc.asr_pb2 as asr_pb2
import mllp_grpc.asr_common_pb2 as asr_common_pb2
from server_lib.decoder_mng import DecoderException


def ee_format(txt):
    """Raise a DecoderException carrying the ERR_WRONG_FORMAT status code.

    Args:
        txt: Message attached to the exception.

    Raises:
        DecoderException: Always.
    """
    code = asr_common_pb2.DecodeResponse.Status.Code.ERR_WRONG_FORMAT
    raise DecoderException(code=code, msg=txt)


class Demultiplexor:
    """Fan a single packet stream out to one queue per requested system.

    The leading packets that carry a ``system_id`` field each open one
    stream.  Every later packet is appended to all queues.  A new packet is
    pulled from the source only once every stream has emptied its queue; the
    last stream to run dry performs the refill.

    Calling the instance with a queue index returns a generator over that
    stream's packets, starting with the stream's own ``system_id`` packet.

    Attributes:
        pos2sysid: ``system_id`` of each stream, indexed by queue position.
    """

    def __init__(self, request_iterator):
        """Read the header packets and prime every queue.

        Args:
            request_iterator: Iterable of packets exposing ``HasField`` and,
                for header packets, a ``system_id`` attribute.

        Raises:
            DecoderException: If no ``system_id`` packet comes first, or if
                the packet following the headers carries no ``data`` field.
        """
        # One explicit iterator is shared by the header scan below and by
        # the later refills, so plain iterables (e.g. a list) work too:
        # calling next() on a non-iterator would raise TypeError.
        self._it = iter(request_iterator)
        self._qs = []
        self.pos2sysid = []
        # The first packets are the system identifiers.
        for o in self._it:
            if not o.HasField('system_id'):
                break
            self._qs.append([o])
            self.pos2sysid.append(o.system_id)
        if not self._qs:
            ee_format('no system_id were provided')
        # 'o' is bound here: reaching this line means at least one packet
        # was read.  If the source ended on a header, 'o' is that header.
        if not o.HasField('data'):
            ee_format('no data packages were provided')
        # Insert the first data packet in all queues.
        for q in self._qs:
            q.append(o)
        # Init state.
        self._stop = False
        self._counter = self.num_streams
        self._cv = threading.Condition()
    # end __init__

    def __refill_nonsecure(self):
        """Count one drained queue and, on the last one, fetch a new packet.

        Called with ``self._cv`` held.  Exceptions raised by the source
        iterator propagate; ``__refill`` is the wrapper that handles them.
        """
        # Decrement and check.
        if self._stop:
            return
        self._counter -= 1
        if self._counter > 0:
            return  # Only the last stream to run dry refills.
        # Fill.
        o = next(self._it, None)
        if o is None:
            self._stop = True
        else:
            self._counter = self.num_streams
            for q in self._qs:
                q.append(o)
        self._cv.notify_all()

    def __refill(self):
        """Refill the queues; on any failure stop every stream, re-raise."""
        try:
            self.__refill_nonsecure()
        except BaseException:
            # Wake the other consumers so they do not wait forever for a
            # packet that will never arrive.
            self._stop = True
            self._cv.notify_all()
            raise

    def __call__(self, q_id):
        """Yield the packets of one stream until the source is exhausted.

        Args:
            q_id: Index of the queue (stream) to read.

        Yields:
            The packets of queue ``q_id``, in arrival order.
        """
        while True:
            # Try to get the next element.
            with self._cv:
                q = self._qs[q_id]
                self._cv.wait_for(lambda: len(q) > 0 or self._stop)
                if self._stop and len(q) == 0:
                    return  # Finished.
                o = q.pop(0)
                if not q:
                    self.__refill()
            # Yield outside the lock so a slow consumer does not block
            # the other streams.
            yield o
    # end __call__

    @property
    def num_streams(self):
        """Number of streams opened by the header packets."""
        return len(self._qs)

# end Demultiplexor


class SegmentLangId:
    """Per-segment record of the latest normalised score of each language.

    Segments form a lazily built chain: ``next`` creates the following
    segment on first access and returns the same object afterwards.
    """

    def __init__(self, num_langs, seg_id=0):
        """Create the record for segment ``seg_id`` with ``num_langs`` slots."""
        self._seg_id = seg_id
        self._next = None
        # Despite the name, each slot holds the most recent per-frame score
        # reported for that language; it is overwritten on every update.
        self._max = [float('-inf')] * num_langs
        self._NL = num_langs

    def __get_best(self):
        """Return the index with the highest score (lowest index on ties)."""
        lang, score = 0, self._max[0]
        for l in range(1, self._NL):
            val = self._max[l]
            if val > score:
                lang, score = l, val
        return lang

    def __call__(self, lang_id, score, nframes):
        """Record a score and return the best ``(seg_id, lang_id)`` so far.

        Args:
            lang_id: Index of the language reporting the score.
            score: Total score; divided by ``nframes`` before being stored.
            nframes: Number of frames; when 0 the stored score is left
                untouched (this also avoids a division by zero).

        Returns:
            Tuple ``(seg_id, best_lang_id)`` for the current segment.
        """
        if nframes != 0:
            self._max[lang_id] = score / nframes
        return self._seg_id, self.__get_best()

    @property
    def next(self):
        """The following segment's record, created on first access."""
        if self._next is None:
            self._next = SegmentLangId(self._NL, self._seg_id + 1)
        return self._next

# end SegmentLangId


class LangId:
    """Track, for every stream, the segment record it is currently on.

    All streams start on the same ``SegmentLangId``; a stream moves to the
    next record in the chain after it reports a final result.
    """

    def __init__(self, num_langs):
        """Point all ``num_langs`` streams at a shared first segment."""
        lid = SegmentLangId(num_langs)
        self._current_lang = [lid] * num_langs

    def __call__(self, lang_id, score, nframes, is_final):
        """Record a score and return the best ``(seg_id, lang_id)``.

        Args:
            lang_id: Index of the reporting stream.
            score: Score forwarded to the stream's current segment record.
            nframes: Frame count forwarded with the score.
            is_final: When true, the stream advances to the next segment
                after this update.

        Returns:
            Tuple ``(seg_id, best_lang_id)`` for the segment that was
            current when the call was made.
        """
        current = self._current_lang[lang_id]
        ret = current(lang_id, score, nframes)
        if is_final:
            self._current_lang[lang_id] = current.next
        return ret

# end LangId


class ExceptionContainer:
    """Holder for at most one exception; truthy once one has been set."""

    def __init__(self):
        self._e = None

    def __bool__(self):
        return self._e is not None

    def set(self, e):
        """Store ``e``, replacing any exception stored before."""
        self._e = e

    @property
    def exception(self):
        """The stored exception, or ``None`` if none was set."""
        return self._e

# end ExceptionContainer


class Counter:
    """Mutable integer cell that can be shared between closures.

    No locking is done here.
    """

    def __init__(self, val):
        self.val = val

    def dec(self):
        """Decrease the value by one."""
        self.val -= 1

# end Counter


# May raise exceptions of any kind, including DecoderException.
def decode(mng, request_iterator, ii):
    """Decode one request stream with every requested system in parallel.

    The request stream is split with a ``Demultiplexor``.  One worker thread
    per system runs ``mng.decode`` over its own copy of the stream, and the
    results of all workers are yielded from this generator in the order
    they are produced.

    Args:
        mng: Object whose ``decode(packets, log)`` yields results exposing
            ``score``, ``nframes`` and ``eos``.
        request_iterator: Source of request packets handed to
            ``Demultiplexor``.
        ii: Callable invoked with a text message; each message is prefixed
            with the system id of the worker that emitted it.

    Yields:
        ``asr_pb2.DecodeMultiResponse`` objects wrapping each result with
        its segment id, the best system so far and the current system.

    Raises:
        BaseException: The first exception raised inside a worker is
            re-raised here once all workers have been stopped.  Results
            still buffered at that point are discarded.
    """
    # Launch threads.
    demux = Demultiplexor(request_iterator)
    output_buffer = []
    error = ExceptionContainer()
    cv = threading.Condition()
    lid = LangId(demux.num_streams)
    stop_counter = Counter(demux.num_streams)

    def transcribe(lang_id):
        """Run one system over its stream and buffer every result."""
        def my_ii(txt):
            ii('[sys_id:%d] %s' % (demux.pos2sysid[lang_id], txt))

        for o in mng.decode(demux(lang_id), my_ii):
            with cv:
                seg_id, best_lid = lid(lang_id, o.score, o.nframes, o.eos)
                best_sys_id = demux.pos2sysid[best_lid]
                curr_sys_id = demux.pos2sysid[lang_id]
                res = asr_pb2.DecodeMultiResponse(res=o, segment_id=seg_id,
                                                  best_system=best_sys_id,
                                                  current_system=curr_sys_id)
                output_buffer.append(res)
                cv.notify_all()
        with cv:
            stop_counter.dec()
            cv.notify_all()

    def transcribe_err(lang_id):
        """Run ``transcribe`` and record the first exception any worker hits."""
        try:
            transcribe(lang_id)
        except BaseException as e:
            with cv:
                if not error:
                    error.set(e)
                cv.notify_all()

    def halt_demux():
        """Wake and terminate every stream of the demultiplexer.

        The streams advance in lockstep: a new packet is read only after
        every stream has drained its queue.  A worker that died therefore
        leaves the others blocked forever, and ``pool.shutdown()`` would
        never return.  Demultiplexor has no public way to end its streams
        early, so its stop flag is set directly under its own condition.
        """
        with demux._cv:
            demux._stop = True
            demux._cv.notify_all()

    pool = cf.ThreadPoolExecutor(max_workers=demux.num_streams)
    try:
        for i in range(demux.num_streams):
            pool.submit(transcribe_err, i)
        # Process the output.
        while True:
            with cv:
                cv.wait_for(lambda: len(output_buffer) > 0
                            or stop_counter.val == 0 or error)
                # An empty buffer here means stop_counter.val == 0:
                # every worker finished and all results were consumed.
                if error or not output_buffer:
                    break
                ret = output_buffer.pop(0)
            # Yield outside the lock so workers can keep producing.
            yield ret
    finally:
        # Also runs when the consumer closes this generator early, so the
        # worker threads are always released and joined.
        halt_demux()
        pool.shutdown()
    if error:
        raise error.exception
# end decode