123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291 |
- __all__= ['DecoderMng','DecoderException']
- import threading,grpc,collections,traceback
- import logging as Log
- import mllp_grpc.asr_system_pb2_grpc as asr_system_pb2_grpc
- import mllp_grpc.asr_system_pb2 as asr_system_pb2
- import mllp_grpc.asr_common_pb2 as asr_common_pb2
- import google.protobuf.empty_pb2 as empty_pb2
# A registered ASR system: ``info`` is the system description reported by
# its hosts, ``hosts`` is a dict whose keys are the (host, port) pairs that
# serve it.
System = collections.namedtuple('System', ('info', 'hosts'))

# Aggregated view of one system (decoder counts summed over its hosts).
SystemInfo = collections.namedtuple(
    'SystemInfo',
    ('id', 'info', 'num_decoders', 'num_decoders_available'),
)

# View of a single host serving a system.
HostInfo = collections.namedtuple(
    'HostInfo',
    ('info', 'num_decoders', 'num_decoders_available', 'port', 'host'),
)
def II(txt):
    """Log *txt* at INFO level, prefixed with ``(DecoderMng)``."""
    message = '(DecoderMng) %s' % txt
    Log.info(message)
class DecoderException(Exception):
    """Decoding error that carries a status code for the client.

    Attributes:
        code: status code to report (``DecoderMng.decode`` passes
            ``asr_common_pb2.DecodeResponse.Status.Code`` values).
        msg: human-readable description of the error.
    """

    def __init__(self, code, msg):
        # Forward both values to Exception so ``args`` is populated; pickling
        # re-creates the exception by calling __init__ with ``args``.
        super().__init__(code, msg)
        self.code = code
        self.msg = msg

    def __str__(self):
        # Show the message rather than the (code, msg) args tuple.
        return str(self.msg)
-
class LazzyIterator:
    """Iterator wrapper whose data is withheld until ``ready()`` is called.

    Calling the instance returns a stream (generator) over the wrapped
    iterator. The stream blocks on its first ``next()`` until ``ready()``
    has been called. ``DecoderMng.decode`` passes such a stream as the
    request iterator of a gRPC call and calls ``ready()`` once the remote
    decoder answers READY.

    Every call creates a new stream, and only the most recently created one
    delivers data. Older streams (handed to calls that were abandoned in
    favour of another host) finish empty when woken, so they cannot steal
    items from the shared wrapped iterator.
    """

    def __init__(self, i):
        self._i = i                       # wrapped iterator, shared by all streams
        self._cv = threading.Condition()  # guards _ready and _generation
        self._ready = False               # set once by ready()
        self._generation = 0              # number of streams created so far

    def __call__(self):
        """Return a new stream over the wrapped iterator.

        The generation is taken eagerly, at creation time, so the most
        recently *created* stream wins regardless of which consumer calls
        ``next()`` first.
        """
        with self._cv:
            self._generation += 1
            generation = self._generation
        return self._stream(generation)

    def _stream(self, generation):
        """Wait for ``ready()``, then yield the wrapped items unless superseded."""
        with self._cv:
            # wait_for loops on the predicate, so spurious wakeups are harmless.
            self._cv.wait_for(lambda: self._ready)
            superseded = generation != self._generation
        if superseded:
            return
        # Iterate outside the lock so ready() can never block on a consumer.
        yield from self._i

    def ready(self):
        """Release the wrapped data to waiting streams."""
        with self._cv:
            self._ready = True
            # Wake every waiter: superseded streams must exit as well, and a
            # single notify() could wake only a stale stream.
            self._cv.notify_all()
# end LazzyIterator
class DecoderMng:
    """Registry of ASR systems and of the gRPC hosts that serve them.

    Hosts are registered with ``add``. Each distinct system id reported by a
    host gets an integer index (``sys_id``) into ``self._sys``. That index
    is what ``decode`` expects in the first request package and what is
    exposed as ``SystemInfo.id``. Hosts that stop answering, or whose
    reported id/tag no longer matches, are dropped while iterating.

    ``self._lock`` guards ``_sys_map`` and ``_sys``. It is a plain
    (non-reentrant) ``threading.Lock``, so it is never held while calling
    a method that takes it again, nor during network calls.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._sys_map = {}  # system id (info.info.id) -> index into self._sys
        self._sys = []      # registered System tuples, indexed by sys_id

    def __get_sys_info(self, host, port, ii=II):
        """Query ``host:port`` with the GetSystemInfo RPC and return the reply.

        Raises:
            Exception: if the server cannot be queried; the underlying error
                is chained as ``__cause__``.
        """
        addr = '%s:%d' % (host, port)
        try:
            with grpc.insecure_channel(addr) as channel:
                stub = asr_system_pb2_grpc.ASRSystemStub(channel)
                return stub.GetSystemInfo(empty_pb2.Empty())
        except Exception as err:
            ii('unable to connect to %s:%d' % (host, port))
            raise Exception('%s:%d is not a valid ASR System server'
                            % (host, port)) from err

    def __remove_host(self, sys_id, host, port, ii=II):
        """Forget ``(host, port)`` for system ``sys_id``; no-op if absent."""
        with self._lock:
            system = self._sys[sys_id]
            ii('trying to remove host %s:%d from system %s'
               % (host, port, system.info.id))
            try:
                del system.hosts[(host, port)]
            except KeyError:
                pass
            else:
                ii('host removed')

    def __iter_sys_stubs(self, sys_id, ii=II):
        """Yield ``(stub, info, (host, port))`` for each live host of ``sys_id``.

        Every host is contacted with GetSystemInfo. Hosts that cannot be
        reached, or that now report a different system id or tag, are
        logged and removed from the system. The channel of a yielded stub
        stays open until the generator is resumed or closed. Errors raised
        while contacting a host are logged, not propagated.
        """
        with self._lock:
            system = self._sys[sys_id]
            # Snapshot: hosts may be removed (under the lock) while iterating.
            hosts = list(system.hosts)
        for host, port in hosts:
            try:
                addr = '%s:%d' % (host, port)
                with grpc.insecure_channel(addr) as channel:
                    stub = asr_system_pb2_grpc.ASRSystemStub(channel)
                    info = stub.GetSystemInfo(empty_pb2.Empty())
                    if info.info.id != system.info.id:
                        ii('bad identifier for host %s:%d: %s != %s'
                           % (host, port, info.info.id, system.info.id))
                        self.__remove_host(sys_id, host, port, ii)
                    elif info.info.tag != system.info.tag:
                        ii('bad tag for host %s:%d: %s != %s'
                           % (host, port, info.info.tag, system.info.tag))
                        self.__remove_host(sys_id, host, port, ii)
                    else:
                        yield stub, info, (host, port)
            except grpc.RpcError:
                traceback.print_exc()
                ii('unable to connect to %s:%d' % (host, port))
                self.__remove_host(sys_id, host, port, ii)
            except Exception:
                # GeneratorExit (a BaseException) is deliberately NOT caught:
                # when the consumer stops early (e.g. decode returns mid-loop)
                # the generator must end instead of probing the other hosts.
                traceback.print_exc()
                ii('unknown error')
    # end __iter_sys_stubs

    def __num_connexions(self, sys_id, ii=II):
        """Return how many hosts of ``sys_id`` currently answer correctly.

        Side effect: unreachable or mismatching hosts are removed.
        """
        return sum(1 for _ in self.__iter_sys_stubs(sys_id, ii))

    def __count_decoders(self, sys_id, ii=II):
        """Sum decoder counts over the enabled live hosts of ``sys_id``.

        Returns:
            ``(num_decoders, num_decoders_available)``.
        """
        num_decoders = 0
        num_decoders_available = 0
        for _stub, sinfo, _addr in self.__iter_sys_stubs(sys_id, ii):
            if sinfo.enabled:
                num_decoders += sinfo.num_decoders
                num_decoders_available += sinfo.num_decoders_available
        return num_decoders, num_decoders_available

    def add(self, host, port, ii=II):
        """Register ``host:port`` as a server of the system it reports.

        A system id not seen before gets a new ``sys_id``. If the system is
        known but the host reports a different tag, the system entry is
        replaced (dropping its previous hosts). That replacement is refused
        if a previously registered host still answers with the old tag.

        Raises:
            Exception: if ``host:port`` cannot be queried.
            RuntimeError: on a tag mismatch with live existing hosts.
        """
        info = self.__get_sys_info(host, port, ii)
        with self._lock:
            sys_id = self._sys_map.get(info.info.id)
        # Count live hosts outside the lock: __iter_sys_stubs takes the
        # non-reentrant lock itself and performs network calls.
        numcon = 0 if sys_id is None else self.__num_connexions(sys_id, ii)
        with self._lock:
            # Look the system up again on purpose: another thread may have
            # registered it while the lock was released.
            sys_id = self._sys_map.get(info.info.id)
            if sys_id is None:
                sys_id = len(self._sys)
                self._sys_map[info.info.id] = sys_id
                self._sys.append(System(info=info.info, hosts={}))
            elif self._sys[sys_id].info.tag != info.info.tag:
                current = self._sys[sys_id]
                if current.hosts and numcon > 0:
                    raise RuntimeError('tag mismatch %s != %s'
                                       % (current.info.tag, info.info.tag))
                self._sys[sys_id] = System(info=info.info, hosts={})
            self._sys[sys_id].hosts[(host, port)] = True
    # end add

    def get_system_info(self, system_id, ii=II):
        """Return ``(True, SystemInfo)`` for the system identified by ``system_id``.

        ``system_id`` is the system's own identifier (``info.id``), not the
        integer ``sys_id``. Returns ``(False, Ellipsis)`` when the system is
        unknown or none of its live, enabled hosts reports any decoder.
        """
        with self._lock:
            # _sys_map keys are exactly the info.id of each registered system.
            sys_id = self._sys_map.get(system_id)
            if sys_id is None:
                return False, ...
            system = self._sys[sys_id]
        num_decoders, num_decoders_available = self.__count_decoders(sys_id, ii)
        if num_decoders > 0:
            return True, SystemInfo(info=system.info,
                                    num_decoders=num_decoders,
                                    num_decoders_available=num_decoders_available,
                                    id=sys_id)
        return False, ...
    # end get_system_info

    def get_systems_info(self, ii=II):
        """Yield a ``SystemInfo`` for each registered system that has decoders."""
        with self._lock:
            nsys = len(self._sys)
        for sys_id in range(nsys):
            with self._lock:
                system = self._sys[sys_id]
            num_decoders, num_decoders_available = self.__count_decoders(sys_id, ii)
            if num_decoders > 0:
                yield SystemInfo(info=system.info,
                                 num_decoders=num_decoders,
                                 num_decoders_available=num_decoders_available,
                                 id=sys_id)
    # end get_systems_info

    def get_hosts_info(self, ii=II):
        """Yield a ``HostInfo`` for every live host of every registered system."""
        with self._lock:
            nsys = len(self._sys)
        for sys_id in range(nsys):
            with self._lock:
                system = self._sys[sys_id]
            for _stub, sinfo, (host, port) in self.__iter_sys_stubs(sys_id, ii):
                yield HostInfo(info=system.info,
                               num_decoders=sinfo.num_decoders,
                               num_decoders_available=sinfo.num_decoders_available,
                               host=host,
                               port=port)
    # end get_hosts_info

    def decode(self, request_iterator, ii=II):
        """Forward a decoding request stream to an available host.

        The first package of ``request_iterator`` must carry ``system_id``
        (the integer ``sys_id`` of a registered system). The remaining
        packages carry ``token`` or ``data`` and are forwarded as
        ``DataPackage`` messages. Each enabled host of the system that
        reports free decoders is tried in turn. The input data is released
        to the first one that answers READY. Its responses are yielded as
        they arrive, the first forwarded one annotated with ``host_info``.

        Raises:
            DecoderException: malformed request, unknown system, or no
                recogniser available.
            Exception: errors raised by the gRPC call to the chosen host
                propagate unchanged.
        """
        def ee_format(txt):
            code = asr_common_pb2.DecodeResponse.Status.Code.ERR_WRONG_FORMAT
            raise DecoderException(code=code, msg=txt)

        def ee_unk_system(sys_id):
            code = asr_common_pb2.DecodeResponse.Status.Code.ERR_UNKNOWN_SYSTEM
            raise DecoderException(code=code,
                                   msg='%d is not a valid ASR system' % sys_id)

        def transform_data(iterator):
            # NOTE(review): this generator is consumed by the gRPC stub
            # (presumably on another thread), so a DecoderException raised
            # here may not reach decode's caller directly — confirm.
            try:
                for o in iterator:
                    if o.HasField('token'):
                        yield asr_common_pb2.DataPackage(token=o.token)
                    elif not o.HasField('data'):
                        ee_format('a system_id package was provided'
                                  ' when a data package was expected')
                    else:
                        yield asr_common_pb2.DataPackage(data=o.data)
            except grpc.RpcError:
                ii('an expected error ocurred while reading input data,'
                   ' probably connexion was closed by client')

        # Select the system.
        p = next(request_iterator, None)
        if p is None:
            ee_format('no packages were sent for decoding')
        if not p.HasField('system_id'):
            ee_format('First package must contain a system identifier')
        # Reject negative ids explicitly: list indexing would otherwise map
        # them to systems counted from the end of self._sys.
        if p.system_id < 0:
            ee_unk_system(p.system_id)
        try:
            with self._lock:
                system = self._sys[p.system_id]
        except (IndexError, TypeError):
            ee_unk_system(p.system_id)
        ii('starting decoding with system: %s' % system.info.id)

        # Try hosts in order until one accepts the stream.
        add_host_info = True
        input_data = LazzyIterator(transform_data(request_iterator))
        for stub, sinfo, addr in self.__iter_sys_stubs(p.system_id, ii):
            if not (sinfo.enabled and sinfo.num_decoders_available > 0):
                continue
            ok = False
            for o in stub.Decode(input_data()):
                if o.status.code == asr_common_pb2.DecodeResponse.Status.READY:
                    ii('connexion stablished with %s:%d' % (addr[0], addr[1]))
                    # The host accepted: let the input data flow to it.
                    input_data.ready()
                    ok = True
                    continue
                elif o.status.code == asr_common_pb2.DecodeResponse.Status.ERR_NO_RECO_AVAILABLE:
                    break  # This host cannot decode; abort it.
                if add_host_info:
                    tmp = asr_common_pb2.DecodeResponse.HostInfo(host=addr[0],
                                                                 port=addr[1])
                    o = asr_common_pb2.DecodeResponse(status=o.status,
                                                      hyp_novar=o.hyp_novar,
                                                      hyp_var=o.hyp_var,
                                                      eos=o.eos,
                                                      score=o.score,
                                                      nframes=o.nframes,
                                                      host_info=tmp)
                    add_host_info = False
                yield o
            if ok:
                ii('decoding finished')
                return  # Decoding completed on this host.
        # Reaching this point means no host accepted the stream.
        code = asr_common_pb2.DecodeResponse.Status.Code.ERR_NO_RECO_AVAILABLE
        raise DecoderException(code=code,
                               msg=("no recogniser available for ASR"
                                    " system '%s'") % system.info.id)
    # end decode

# end DecoderMng
|