__all__ = ['DecoderMng', 'DecoderException']

import threading
import grpc
import collections
import traceback
import logging as Log
import mllp_grpc.asr_system_pb2_grpc as asr_system_pb2_grpc
import mllp_grpc.asr_system_pb2 as asr_system_pb2
import mllp_grpc.asr_common_pb2 as asr_common_pb2
import google.protobuf.empty_pb2 as empty_pb2

# A registered ASR system: 'info' is the `info` field of a GetSystemInfo
# reply (it carries at least `id` and `tag`), 'hosts' is a dict whose keys
# are the (host, port) pairs serving the system (values are always True).
System = collections.namedtuple('System', ['info', 'hosts'])

# Aggregated view of a system, produced by get_system_info/get_systems_info.
# 'id' is the internal index of the system (the value clients send as
# `system_id` to decode), not the system's own `info.id`.
SystemInfo = collections.namedtuple(
    'SystemInfo', ['id', 'info', 'num_decoders', 'num_decoders_available'])

# Per-host view, produced by get_hosts_info.
HostInfo = collections.namedtuple(
    'HostInfo',
    ['info', 'num_decoders', 'num_decoders_available', 'port', 'host'])


def II(txt):
    """Log *txt* at INFO level, prefixed with '(DecoderMng)'."""
    Log.info('(DecoderMng) %s', txt)


class DecoderException(Exception):
    """Error raised by DecoderMng.decode.

    Attributes:
        code: a DecodeResponse.Status.Code value describing the failure.
        msg: human-readable description of the failure.
    """

    def __init__(self, code, msg):
        # Forward msg to Exception so that str(exc) / exc.args are
        # meaningful (previously they were empty).
        super().__init__(msg)
        self.code = code
        self.msg = msg


class LazzyIterator:
    """Hold back an iterator until ready() is called.

    Each call returns a fresh generator over the shared wrapped iterator.
    On first use the generator blocks until ready() has been called. This
    lets a request stream be handed to a gRPC call while the data is
    withheld until the server has answered READY.

    Only the generator returned by the most recent call may ever consume
    the wrapped iterator. Older generators belong to abandoned attempts and
    terminate without consuming anything as soon as a newer one is created.
    Generators still waiting when abort() is called (before ready()) also
    terminate. This stops an abandoned attempt from stealing the wakeup, or
    the data, meant for the live one, and it releases threads that would
    otherwise wait forever.
    """

    def __init__(self, i):
        self._i = i
        self._cv = threading.Condition()
        self._ready = False
        self._aborted = False
        self._generation = 0  # id of the most recent __call__

    def __call__(self):
        # Register the new attempt eagerly (not inside the generator body)
        # so that older waiters are released even before this generator is
        # first iterated.
        with self._cv:
            self._generation += 1
            generation = self._generation
            self._cv.notify_all()
        return self._gated(generation)

    def _gated(self, generation):
        """Generator body for __call__: wait for ready(), then stream data."""
        with self._cv:
            # wait_for re-checks the predicate, so spurious wakeups are
            # harmless (a plain `if ...: wait()` is not).
            self._cv.wait_for(lambda: (self._ready or self._aborted
                                       or generation != self._generation))
            if generation != self._generation or not self._ready:
                # Superseded by a newer attempt, or aborted before READY.
                return
        for v in self._i:
            yield v

    def ready(self):
        """Let the current generator start consuming the wrapped iterator."""
        with self._cv:
            self._ready = True
            # notify_all: notify() would wake the oldest waiter, which may be
            # a stale generator from an abandoned attempt.
            self._cv.notify_all()

    def abort(self):
        """Release generators still waiting; no effect once ready()."""
        with self._cv:
            self._aborted = True
            self._cv.notify_all()
# end LazzyIterator


class DecoderMng:
    """Registry of ASR systems and of the hosts serving each of them.

    Systems are registered through add() and are referred to internally by
    their index in self._sys. A host is dropped from the registry when it
    is contacted and either fails to answer or reports a different system
    id or tag.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._sys_map = {}  # system info.id -> index into self._sys
        self._sys = []      # list of System

    # NOTE: may raise.
    def __get_sys_info(self, host, port, ii=II):
        """Return the GetSystemInfo reply of the server at host:port.

        Raises:
            Exception: if the server cannot be contacted or the call fails;
                the underlying error is chained as __cause__.
        """
        try:
            addr = '%s:%d' % (host, port)
            with grpc.insecure_channel(addr) as channel:
                stub = asr_system_pb2_grpc.ASRSystemStub(channel)
                return stub.GetSystemInfo(empty_pb2.Empty())
        except Exception as err:
            ii('unable to connect to %s:%d' % (host, port))
            raise Exception('%s:%d is not a valid ASR System server'
                            % (host, port)) from err

    def __remove_host(self, sys_id, host, port, ii=II):
        """Drop (host, port) from system *sys_id*; unknown hosts are ignored."""
        with self._lock:
            system = self._sys[sys_id]
            ii(('trying to remove host %s:%d'
                ' from system %s') % (host, port, system.info.id))
            key = (host, port)
            # Another thread may already have removed it.
            if key in system.hosts:
                del system.hosts[key]
                ii('host removed')

    def __iter_sys_stubs(self, sys_id, ii=II):
        """Yield (stub, info, (host, port)) for each usable host of *sys_id*.

        Each registered host is contacted with GetSystemInfo. A host is
        skipped and removed from the host list in two cases: it cannot be
        reached, or its reported id or tag differs from the registered
        system's. Other errors are logged and the host is skipped. The
        channel behind a yielded stub stays open until the consumer
        advances or closes this generator.
        """
        with self._lock:
            system = self._sys[sys_id]
            hosts = list(system.hosts)
        for host, port in hosts:
            # Probe the connection and check the identifier.
            try:
                addr = '%s:%d' % (host, port)
                with grpc.insecure_channel(addr) as channel:
                    stub = asr_system_pb2_grpc.ASRSystemStub(channel)
                    info = stub.GetSystemInfo(empty_pb2.Empty())
                    if info.info.id != system.info.id:
                        ii(('bad identifier for host %s:%d:'
                            ' %s != %s') % (host, port, info.info.id,
                                            system.info.id))
                        self.__remove_host(sys_id, host, port, ii)
                    elif info.info.tag != system.info.tag:
                        ii(('bad tag for host %s:%d:'
                            ' %s != %s') % (host, port, info.info.tag,
                                            system.info.tag))
                        self.__remove_host(sys_id, host, port, ii)
                    else:
                        yield stub, info, (host, port)
            except grpc.RpcError:
                traceback.print_exc()
                ii('unable to connect to %s:%d' % (host, port))
                self.__remove_host(sys_id, host, port, ii)
            except Exception:
                # GeneratorExit is deliberately NOT caught (it is not an
                # Exception): closing this generator early must stop it here
                # rather than go on contacting the remaining hosts.
                traceback.print_exc()
                ii('unknown error')
    # end __iter_sys_stubs

    def __num_connexions(self, sys_id, ii=II):
        """Return how many hosts of *sys_id* currently answer correctly."""
        return sum(1 for _ in self.__iter_sys_stubs(sys_id, ii))

    def __count_decoders(self, sys_id, ii=II):
        """Return (num_decoders, num_decoders_available) of *sys_id*.

        Only hosts reporting `enabled` are counted.
        """
        num_decoders = 0
        num_decoders_available = 0
        for _stub, sinfo, _addr in self.__iter_sys_stubs(sys_id, ii):
            if sinfo.enabled:
                num_decoders += sinfo.num_decoders
                num_decoders_available += sinfo.num_decoders_available
        return num_decoders, num_decoders_available

    # NOTE: may raise.
    def add(self, host, port, ii=II):
        """Register the ASR System server at host:port.

        A server reporting an unknown system id creates a new system. A
        server reporting a known id with a different tag replaces that
        system (dropping its hosts) only if none of its current hosts
        answers correctly.

        Raises:
            Exception: if host:port does not answer GetSystemInfo.
            RuntimeError: on a tag mismatch while the registered system still
                has reachable hosts.
        """
        info = self.__get_sys_info(host, port, ii)
        system_key = info.info.id
        with self._lock:
            sys_id = self._sys_map.get(system_key)
        # Probe existing hosts outside the lock: __iter_sys_stubs acquires
        # the (non-reentrant) lock itself and performs network calls.
        numcon = 0 if sys_id is None else self.__num_connexions(sys_id, ii)

        # Add host.
        with self._lock:
            # Looked up again on purpose: the system may have been registered
            # by another thread while the lock was released.
            sys_id = self._sys_map.get(system_key)
            if sys_id is None:
                sys_id = len(self._sys)
                self._sys_map[system_key] = sys_id
                self._sys.append(System(info=info.info, hosts={}))
            elif self._sys[sys_id].info.tag != info.info.tag:
                current = self._sys[sys_id]
                if len(current.hosts) > 0 and numcon > 0:
                    raise RuntimeError(
                        'tag mismatch %s != %s' % (current.info.tag,
                                                   info.info.tag))
                # No live host keeps the old tag: start afresh with the new one.
                self._sys[sys_id] = System(info=info.info, hosts={})
            self._sys[sys_id].hosts[(host, port)] = True
    # end add

    def get_system_info(self, system_id, ii=II):
        """Return (True, SystemInfo) for the system whose info.id is *system_id*.

        Returns (False, Ellipsis) in two cases: the system is unknown, or
        none of its hosts currently reports enabled decoders.
        """
        with self._lock:
            sys_id = self._sys_map.get(system_id)
            if sys_id is None:
                return False, ...
            system = self._sys[sys_id]

        num_decoders, num_decoders_available = self.__count_decoders(sys_id,
                                                                     ii)
        if num_decoders > 0:
            return True, SystemInfo(info=system.info,
                                    num_decoders=num_decoders,
                                    num_decoders_available=num_decoders_available,
                                    id=sys_id)
        return False, ...
    # end get_system_info

    def get_systems_info(self, ii=II):
        """Yield a SystemInfo for every system with enabled decoders."""
        with self._lock:
            nsys = len(self._sys)
        for sys_id in range(nsys):
            with self._lock:
                system = self._sys[sys_id]
            num_decoders, num_decoders_available = self.__count_decoders(
                sys_id, ii)
            if num_decoders > 0:
                yield SystemInfo(info=system.info,
                                 num_decoders=num_decoders,
                                 num_decoders_available=num_decoders_available,
                                 id=sys_id)
    # end get_systems_info

    def get_hosts_info(self, ii=II):
        """Yield a HostInfo for every host currently answering correctly."""
        with self._lock:
            nsys = len(self._sys)
        for sys_id in range(nsys):
            with self._lock:
                system = self._sys[sys_id]
            for _stub, sinfo, addr in self.__iter_sys_stubs(sys_id, ii):
                yield HostInfo(info=system.info,
                               num_decoders=sinfo.num_decoders,
                               num_decoders_available=sinfo.num_decoders_available,
                               host=addr[0],
                               port=addr[1])
    # end get_hosts_info

    # NOTE: may raise DecoderException or Exception.
    def decode(self, request_iterator, ii):
        """Forward a decoding request to the first available host of a system.

        The first message of *request_iterator* must carry a `system_id`.
        That id is the internal system index, as reported in SystemInfo.id.
        The following messages carry `token` or `data` and are forwarded as
        DataPackage messages. Hosts are tried in turn. A host is used only
        if it is enabled and reports available decoders. Data is held back
        until the host answers READY. READY and ERR_NO_RECO_AVAILABLE
        responses are not forwarded. The first forwarded response is
        extended with the serving host's address.

        Yields:
            DecodeResponse messages from the serving host.

        Raises:
            DecoderException: ERR_WRONG_FORMAT if the stream is empty or does
                not start with a system id; ERR_UNKNOWN_SYSTEM for an invalid
                system id; ERR_NO_RECO_AVAILABLE if no host reached READY.
                Errors raised by the gRPC Decode call are not caught here.
        """
        Status = asr_common_pb2.DecodeResponse.Status

        def ee_format(txt):
            code = Status.Code.ERR_WRONG_FORMAT
            raise DecoderException(code=code, msg=txt)

        def ee_unk_system(sys_id):
            code = Status.Code.ERR_UNKNOWN_SYSTEM
            raise DecoderException(code=code,
                                   msg='%d is not a valid ASR system' % sys_id)

        def transform_data(iterator):
            # NOTE(review): this generator is consumed by gRPC's request
            # thread, so the DecoderException raised by ee_format here most
            # likely does not reach decode's caller -- confirm.
            try:
                for o in iterator:
                    if o.HasField('token'):
                        yield asr_common_pb2.DataPackage(token=o.token)
                    elif not o.HasField('data'):
                        ee_format(('a system_id package was provided'
                                   ' when a data package was expected'))
                    else:
                        yield asr_common_pb2.DataPackage(data=o.data)
            except grpc.RpcError:
                ii(('an expected error ocurred while reading input data,'
                    ' probably connexion was closed by client'))

        # Select the system.
        p = next(request_iterator, None)
        if p is None:
            ee_format('no packages were sent for decoding')
        if not p.HasField('system_id'):
            ee_format('First package must contain a system identifier')
        sys_id = p.system_id
        with self._lock:
            # Explicit bounds check: a negative index would otherwise select
            # a system counted from the end of the list.
            if 0 <= sys_id < len(self._sys):
                system = self._sys[sys_id]
            else:
                system = None
        if system is None:
            ee_unk_system(sys_id)
        ii('starting decoding with system: %s' % system.info.id)

        # Look for the first available host and decode with it.
        add_host_info = True
        input_data = LazzyIterator(transform_data(request_iterator))
        try:
            for stub, sinfo, addr in self.__iter_sys_stubs(sys_id, ii):
                if not (sinfo.enabled and sinfo.num_decoders_available > 0):
                    continue
                ok = False
                for o in stub.Decode(input_data()):
                    if o.status.code == Status.READY:
                        ii('connexion stablished with %s:%d' % (addr[0],
                                                                 addr[1]))
                        input_data.ready()
                        ok = True
                        continue
                    elif o.status.code == Status.ERR_NO_RECO_AVAILABLE:
                        break  # Abort decoding on this host.
                    if add_host_info:
                        tmp = asr_common_pb2.DecodeResponse.HostInfo(
                            host=addr[0], port=addr[1])
                        o = asr_common_pb2.DecodeResponse(
                            status=o.status,
                            hyp_novar=o.hyp_novar,
                            hyp_var=o.hyp_var,
                            eos=o.eos,
                            score=o.score,
                            nframes=o.nframes,
                            host_info=tmp)
                        add_host_info = False
                    yield o
                if ok:
                    ii('decoding finished')
                    return  # Finish the decoding process.
        finally:
            # Release any request generator still blocked waiting for a
            # READY that will never come (no effect once ready() was called).
            input_data.abort()

        # Reaching this point means no decoder was available.
        code = Status.Code.ERR_NO_RECO_AVAILABLE
        raise DecoderException(code=code,
                               msg=("no recogniser available for ASR"
                                    " system '%s'") % system.info.id)
    # end decode
# end DecoderMng