123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
- __all__= ['decode']
- import concurrent.futures as cf
- import threading
- import mllp_grpc.asr_pb2 as asr_pb2
- import mllp_grpc.asr_common_pb2 as asr_common_pb2
- from server_lib.decoder_mng import DecoderException
def ee_format(txt):
    """Abort because the request stream is malformed.

    Always raises DecoderException with status code ERR_WRONG_FORMAT and
    *txt* as its message; it never returns normally.
    """
    err_code= asr_common_pb2.DecodeResponse.Status.Code.ERR_WRONG_FORMAT
    raise DecoderException(code=err_code,msg=txt)
class Demultiplexor:
    """Fan one request stream out to several lock-stepped sub-streams.

    The leading packets of *request_iterator* that carry ``system_id`` each
    open one sub-stream (one queue); ``pos2sysid[i]`` is the system id of
    sub-stream ``i``.  Every queue starts with its own system_id packet
    followed by the first data packet, and each later packet is appended to
    all queues.  A new packet is pulled from the request iterator only after
    every sub-stream has emptied its queue, so the sub-streams advance in
    lock-step: they must be consumed concurrently (or interleaved), and a
    sub-stream whose consumer stops reading blocks all the others.

    ``demux(q_id)`` returns a generator over sub-stream *q_id*.
    """
    def __init__(self,request_iterator):
        self._qs= []
        self.pos2sysid= []
        # The first packets are the system identifiers, one per sub-stream.
        for o in request_iterator:
            if not o.HasField('system_id'):
                break
            self._qs.append([o])
            self.pos2sysid.append(o.system_id)
        if not self._qs:
            ee_format('no system_id were provided')
        # `o` is the first packet after the identifiers (or the last
        # identifier if the stream ended): it must carry data.
        if not o.HasField('data'):
            ee_format('no data packages were provided')
        # The first data packet goes to every sub-stream.
        for q in self._qs:
            q.append(o)
        # Shared state, protected by self._cv.
        self._it= request_iterator
        self._stop= False
        # Sub-streams that still have to empty their queue before the next
        # packet is fetched from the request iterator.
        self._counter= self.num_streams
        self._cv= threading.Condition()

    # end __init__
    def __refill_nonsecure(self):
        """Account for one sub-stream having emptied its queue.

        The last sub-stream to do so fetches the next packet and appends it
        to every queue, or sets the stop flag when the input is exhausted.
        Must be called with ``self._cv`` held (it calls ``notify_all``).
        """
        if self._stop:
            return
        self._counter-= 1
        if self._counter>0:
            return  # only the last sub-stream to drain its queue refills
        o= next(self._it,None)
        if o is None:
            self._stop= True
        else:
            self._counter= self.num_streams
            for q in self._qs:
                q.append(o)
        self._cv.notify_all()

    def __refill(self):
        """Run __refill_nonsecure; if it fails (e.g. the request iterator
        raises), set the stop flag and wake every waiting sub-stream before
        propagating the exception, so nobody is left waiting forever."""
        try:
            self.__refill_nonsecure()
        except BaseException:
            self._stop= True
            self._cv.notify_all()
            raise  # bare raise keeps the original traceback untouched

    def __call__(self,q_id):
        """Yield the packets of sub-stream *q_id* (its index in pos2sysid).

        The generator ends once the stop flag is set and the queue is empty.
        """
        q= self._qs[q_id]
        while True:
            with self._cv:
                self._cv.wait_for(lambda: len(q)>0 or self._stop)
                if self._stop and not q:
                    return  # input exhausted (or aborted) and queue drained
                o= q.pop(0)
                if not q:
                    self.__refill()
            # Yield outside the lock so the other sub-streams can progress
            # while this packet is being processed by the consumer.
            yield o
    # end __call__

    @property
    def num_streams(self):
        """Number of sub-streams (one per system_id packet)."""
        return len(self._qs)

# end Demultiplexor
class SegmentLangId:
    """Language scores of one segment, plus a lazily built next segment.

    Each language slot holds the most recent per-frame score reported for
    it (initially -inf).  The best language is the slot with the highest
    score; ties go to the lowest index.
    """
    def __init__(self,num_langs,seg_id=0):
        self._NL= num_langs
        self._seg_id= seg_id
        self._max= [float('-inf')]*num_langs
        self._next= None

    def __get_best(self):
        scores= self._max
        best_idx= 0
        best_val= scores[0]
        for idx in range(1,self._NL):
            if scores[idx]>best_val:
                best_idx= idx
                best_val= scores[idx]
        return best_idx

    def __call__(self,lang_id,score,nframes):
        """Record *score* (normalised by *nframes*) for *lang_id*.

        A report with ``nframes == 0`` leaves the scores untouched.
        Returns ``(seg_id, best_lang_id)`` for this segment.
        """
        if nframes!=0:
            self._max[lang_id]= score/nframes
        return self._seg_id,self.__get_best()

    @property
    def next(self):
        """The following segment, created on first access and then reused."""
        if self._next is None:
            self._next= SegmentLangId(self._NL,self._seg_id+1)
        return self._next

# end SegmentLangId
class LangId:
    """Per-segment language identification across parallel streams.

    Stream ``lang_id`` keeps a pointer to the SegmentLangId of the segment
    it is currently in.  All pointers start at the same first segment and
    advance through ``SegmentLangId.next``, which is created once and then
    reused, so streams at the same segment index share one score table.
    """
    def __init__(self,num_langs):
        first_segment= SegmentLangId(num_langs)
        self._current_lang= [first_segment for _ in range(num_langs)]

    def __call__(self,lang_id,score,nframes,is_final):
        """Report a result of stream *lang_id*.

        Returns ``(seg_id, best_lang_id)`` for the segment the stream is
        in; when *is_final* is true the stream then moves to the next one.
        """
        segment= self._current_lang[lang_id]
        best= segment(lang_id,score,nframes)
        if is_final:
            self._current_lang[lang_id]= segment.next
        return best

# end LangId
class ExceptionContainer:
    """Mutable holder for at most one exception.

    The container is truthy while an exception (any non-None value) is
    stored in it.
    """
    def __init__(self):
        self._e= None

    def __bool__(self):
        return not (self._e is None)

    def set(self,e):
        """Store *e*, replacing whatever was stored before."""
        self._e= e

    @property
    def exception(self):
        """The stored exception, or None if nothing was stored."""
        return self._e

# end ExceptionContainer
class Counter:
    """Tiny mutable counter exposing its value as the ``val`` attribute."""
    def __init__(self,val):
        self.val= val

    def dec(self):
        """Decrease ``val`` by one (there is no lower bound)."""
        self.val= self.val-1

# end Counter
# May raise exceptions of any kind, including DecoderException.
def decode(mng,request_iterator,ii):
    """Decode a multi-system request stream, yielding DecodeMultiResponse.

    The request stream is split by a Demultiplexor into one sub-stream per
    system_id, and each sub-stream is decoded by ``mng.decode`` in its own
    worker thread.  Every result is wrapped with its segment id, the system
    that produced it (``current_system``) and the system currently scoring
    best for that segment (``best_system``), and yielded in arrival order.

    Args:
        mng: object whose ``decode(packets, ii)`` yields results exposing
            ``score``, ``nframes`` and ``eos``.
        request_iterator: iterator of request packets (system ids first,
            then data packets).
        ii: callable receiving informational text messages.

    Raises:
        DecoderException if the request stream is malformed; otherwise the
        first exception raised by any worker, once the workers have been
        stopped and joined.
    """
    demux= Demultiplexor(request_iterator)
    num_streams= demux.num_streams
    output_buffer= []
    error= ExceptionContainer()
    cv= threading.Condition()
    lid= LangId(num_streams)
    stop_counter= Counter(num_streams)

    def transcribe(lang_id):
        # Decode sub-stream lang_id and push every result to output_buffer.
        curr_sys_id= demux.pos2sysid[lang_id]
        def my_ii(txt):
            ii('[sys_id:%d] %s'%(curr_sys_id,txt))
        for o in mng.decode(demux(lang_id),my_ii):
            with cv:
                seg_id,best_lid= lid(lang_id,o.score,o.nframes,o.eos)
                res= asr_pb2.DecodeMultiResponse(res=o,segment_id=seg_id,
                                                 best_system=demux.pos2sysid[best_lid],
                                                 current_system=curr_sys_id)
                output_buffer.append(res)
                cv.notify_all()
        with cv:
            stop_counter.dec()
            cv.notify_all()

    def transcribe_err(lang_id):
        # Worker entry point: keep the first exception for the consumer.
        try:
            transcribe(lang_id)
        except BaseException as e:
            with cv:
                if not error:
                    error.set(e)
                cv.notify_all()

    def stop_demux():
        # Set the demultiplexer's stop flag under its condition (exactly as
        # its own error path does) so that workers waiting for a sibling
        # sub-stream wake up, drain what is already queued and finish.
        # Without this, a worker that died (its queue is never drained, so
        # the lock-step refill never happens) leaves the surviving workers
        # waiting forever and pool.shutdown() hangs.
        with demux._cv:
            demux._stop= True
            demux._cv.notify_all()

    def output_ready():
        return len(output_buffer)>0 or stop_counter.val==0 or bool(error)

    pool= cf.ThreadPoolExecutor(max_workers=num_streams)
    try:
        for i in range(num_streams):
            pool.submit(transcribe_err,i)
        # Forward the results until every worker is done or one fails.
        while True:
            with cv:
                cv.wait_for(output_ready)
                if error:
                    break  # an error takes priority over pending output
                if not output_buffer:
                    assert stop_counter.val==0
                    break  # all workers finished and everything was sent
                ret= output_buffer.pop(0)
            yield ret
    finally:
        # Runs on normal exit, on error and when the consumer closes the
        # generator early: unblock the workers, then join them.
        stop_demux()
        pool.shutdown()
    if error:
        raise error.exception
# end decode
|