decoder_mng.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. __all__= ['DecoderMng','DecoderException']
  2. import threading,grpc,collections,traceback
  3. import logging as Log
  4. import mllp_grpc.asr_system_pb2_grpc as asr_system_pb2_grpc
  5. import mllp_grpc.asr_system_pb2 as asr_system_pb2
  6. import mllp_grpc.asr_common_pb2 as asr_common_pb2
  7. import google.protobuf.empty_pb2 as empty_pb2
  8. System= collections.namedtuple('System',['info','hosts'])
  9. SystemInfo= collections.namedtuple('SystemInfo',
  10. ['id','info','num_decoders',
  11. 'num_decoders_available'])
  12. HostInfo= collections.namedtuple('HostInfo',
  13. ['info','num_decoders',
  14. 'num_decoders_available',
  15. 'port','host'])
  16. def II(txt): Log.info('(DecoderMng) %s'%txt)
  17. class DecoderException(Exception):
  18. def __init__(self,code,msg):
  19. self.code= code
  20. self.msg= msg
  21. class LazzyIterator:
  22. def __init__(self,i):
  23. self._i= i
  24. self._cv= threading.Condition()
  25. self._ready= False
  26. def __call__(self):
  27. with self._cv:
  28. if not self._ready:
  29. self._cv.wait()
  30. for v in self._i:
  31. yield v
  32. def ready(self):
  33. with self._cv:
  34. self._ready= True
  35. self._cv.notify()
  36. # end LazzyIterator
  37. class DecoderMng:
  38. def __init__(self):
  39. self._lock= threading.Lock()
  40. self._sys_map= {} # System id -> int
  41. self._sys= [] # Llista de sistemes
  42. # NOTA! Pot generar excepcions.
  43. # Torna la info d'un sistema
  44. def __get_sys_info(self,host,port,ii=II):
  45. try:
  46. addr= '%s:%d'%(host,port)
  47. with grpc.insecure_channel(addr) as channel:
  48. stub= asr_system_pb2_grpc.ASRSystemStub(channel)
  49. return stub.GetSystemInfo(empty_pb2.Empty())
  50. except:
  51. ii('unable to connect to %s:%d'%(host,port))
  52. raise Exception('%s:%d is not a valid ASR System server'%(host,port))
  53. def __remove_host(self,sys_id,host,port,ii=II):
  54. with self._lock:
  55. ii(('trying to remove host %s:%d'+
  56. ' from system %s')%(host,port,
  57. self._sys[sys_id].info.id))
  58. try:
  59. del self._sys[sys_id].hosts[(host,port)]
  60. ii('host removed')
  61. except:
  62. pass
  63. # Donat un sys_id (un int) torna un iterador sobre parells:
  64. # connexió al hosts d'este sistema i info. Si algun host falla al
  65. # connectar-se o el seu identificador no correspon amb el que
  66. # deuria, el host és ignorat i eliminat de la llista de
  67. # hosts. Aquest iterador no genera excepcions.
  68. def __iter_sys_stubs(self,sys_id,ii=II):
  69. with self._lock:
  70. system= self._sys[sys_id]
  71. hosts= list(system.hosts.keys())
  72. for host,port in hosts:
  73. # Prova connexió i comprova id
  74. try:
  75. addr= '%s:%d'%(host,port)
  76. with grpc.insecure_channel(addr) as channel:
  77. stub= asr_system_pb2_grpc.ASRSystemStub(channel)
  78. info= stub.GetSystemInfo(empty_pb2.Empty())
  79. if info.info.id!=system.info.id:
  80. ii(('bad identifier for host %s:%d:'+
  81. ' %s != %s')%(host,port,info.info.id,
  82. system.info.id))
  83. self.__remove_host(sys_id,host,port,ii)
  84. elif info.info.tag!=system.info.tag:
  85. ii(('bad tag for host %s:%d:'+
  86. ' %s != %s')%(host,port,info.info.tag,
  87. system.info.tag))
  88. self.__remove_host(sys_id,host,port,ii)
  89. else:
  90. yield stub,info,(host,port)
  91. except grpc.RpcError:
  92. traceback.print_exc()
  93. ii('unable to connect to %s:%d'%(host,port))
  94. self.__remove_host(sys_id,host,port,ii)
  95. except GeneratorExit:
  96. pass
  97. except:
  98. traceback.print_exc()
  99. ii('unknown error')
  100. # end __iter_sys_stubs
  101. def __num_connexions(self,sys_id,ii=II):
  102. ret= 0
  103. for stub,sinfo,_ in self.__iter_sys_stubs(sys_id,ii):
  104. ret+= 1
  105. return ret
  106. # NOTA! Pot generar excepcions.
  107. def add(self,host,port,ii=II):
  108. info= self.__get_sys_info(host,port,ii)
  109. with self._lock:
  110. sys_id= self._sys_map.get(info.info.id)
  111. numcon= 0 if sys_id is None else self.__num_connexions(sys_id)
  112. # Add host
  113. with self._lock:
  114. # Repetit a propòsit!!!
  115. sys_id= self._sys_map.get(info.info.id)
  116. if sys_id==None:
  117. sys_id= len(self._sys)
  118. self._sys_map[info.info.id]= sys_id
  119. sys= System(info=info.info,hosts={})
  120. self._sys.append(sys)
  121. elif self._sys[sys_id].info.tag!=info.info.tag:
  122. system= self._sys[sys_id]
  123. hosts= list(system.hosts.keys())
  124. if len(system.hosts)>0 and numcon>0:
  125. raise RuntimeError(
  126. 'tag mismatch %s != %s'%(self._sys[sys_id].info.tag,
  127. info.info.tag))
  128. else:
  129. self._sys[sys_id]= System(info=info.info,hosts={})
  130. key= (host,port)
  131. self._sys[sys_id].hosts[key]= True
  132. # end add
  133. def get_system_info(self,system_id,ii=II):
  134. # Obté el sys_id associat a sytem_id
  135. sys_id,system= -1,...
  136. with self._lock:
  137. for tmp_sys_id,tmp_system in enumerate(self._sys):
  138. if tmp_system.info.id==system_id:
  139. sys_id= tmp_sys_id
  140. system= tmp_system
  141. break
  142. if sys_id==-1: return False,...
  143. # Obté info
  144. num_decoders= 0
  145. num_decoders_available= 0
  146. for stub,sinfo,_ in self.__iter_sys_stubs(sys_id,ii):
  147. if sinfo.enabled:
  148. num_decoders+= sinfo.num_decoders
  149. num_decoders_available+= sinfo.num_decoders_available
  150. if num_decoders>0:
  151. return True,SystemInfo(info=system.info,
  152. num_decoders=num_decoders,
  153. num_decoders_available=num_decoders_available,
  154. id=sys_id)
  155. else: return False,...
  156. # end get_system_info
  157. def get_systems_info(self,ii=II):
  158. # Llista de systemes
  159. with self._lock:
  160. nsys= len(self._sys)
  161. # Itera sobre els sistemes
  162. for sys_id in range(nsys):
  163. # Get info
  164. with self._lock:
  165. system= self._sys[sys_id]
  166. # Pregunta a tots els hosts
  167. num_decoders= 0
  168. num_decoders_available= 0
  169. for stub,sinfo,_ in self.__iter_sys_stubs(sys_id,ii):
  170. if sinfo.enabled:
  171. num_decoders+= sinfo.num_decoders
  172. num_decoders_available+= sinfo.num_decoders_available
  173. # Return
  174. if num_decoders>0:
  175. yield SystemInfo(info=system.info,
  176. num_decoders=num_decoders,
  177. num_decoders_available=num_decoders_available,
  178. id=sys_id)
  179. # end get_systems_info
  180. def get_hosts_info(self,ii=II):
  181. # Llista de systemes
  182. with self._lock:
  183. nsys= len(self._sys)
  184. # Itera sobre els sistemes
  185. for sys_id in range(nsys):
  186. # Get info
  187. with self._lock:
  188. system= self._sys[sys_id]
  189. # Pregunta a tots els hosts
  190. for stub,sinfo,addr in self.__iter_sys_stubs(sys_id,ii):
  191. yield HostInfo(info=system.info,
  192. num_decoders=sinfo.num_decoders,
  193. num_decoders_available=sinfo.num_decoders_available,
  194. host=addr[0],
  195. port=addr[1])
  196. # end get_hosts_info
  197. # NOTA!!! Pot generar excepcions de tipus DecoderException o
  198. # Exception.
  199. def decode(self,request_iterator,ii):
  200. def ee_format(txt):
  201. code=asr_common_pb2.DecodeResponse.Status.Code.ERR_WRONG_FORMAT
  202. raise DecoderException(code=code,msg=txt)
  203. def ee_unk_system(sys_id):
  204. code=asr_common_pb2.DecodeResponse.Status.Code.ERR_UNKNOWN_SYSTEM
  205. raise DecoderException(code=code,msg='%d is not a valid ASR system'%sys_id)
  206. def transform_data(iterator):
  207. try:
  208. for o in iterator:
  209. if o.HasField('token'):
  210. yield asr_common_pb2.DataPackage(token=o.token)
  211. elif not o.HasField('data'):
  212. ee_format(('a system_id package was provided'+
  213. ' when a data package was expected'))
  214. else:
  215. yield asr_common_pb2.DataPackage(data=o.data)
  216. except grpc.RpcError:
  217. ii(('an expected error ocurred while reading input data,'+
  218. ' probably connexion was closed by client'))
  219. # Selecciona sistema
  220. p= next(request_iterator,None)
  221. if p is None: ee_format('no packages were sent for decoding')
  222. if not p.HasField('system_id'):
  223. ee_format('First package must contain a system identifier')
  224. try:
  225. with self._lock:
  226. system= self._sys[p.system_id]
  227. except:
  228. ee_unk_system(p.system_id)
  229. ii('starting decoding with system: %s'%system.info.id)
  230. # Find for the first available system and decode
  231. add_host_info= True
  232. input_data= LazzyIterator(transform_data(request_iterator))
  233. for stub,sinfo,addr in self.__iter_sys_stubs(p.system_id,ii):
  234. if sinfo.enabled and sinfo.num_decoders_available>0:
  235. ok= False
  236. for o in stub.Decode(input_data()):
  237. if o.status.code==asr_common_pb2.DecodeResponse.Status.READY:
  238. ii('connexion stablished with %s:%d'%(addr[0],addr[1]))
  239. input_data.ready()
  240. ok= True
  241. continue
  242. elif o.status.code==asr_common_pb2.DecodeResponse.Status.ERR_NO_RECO_AVAILABLE:
  243. break # Abort decoding
  244. if add_host_info:
  245. tmp= asr_common_pb2.DecodeResponse.HostInfo(host=addr[0],port=addr[1])
  246. new_o= asr_common_pb2.DecodeResponse(status=o.status,
  247. hyp_novar=o.hyp_novar,
  248. hyp_var=o.hyp_var,
  249. eos=o.eos,
  250. score=o.score,
  251. nframes=o.nframes,
  252. host_info=tmp)
  253. add_host_info= False
  254. o= new_o
  255. yield o
  256. if ok:
  257. ii('decoding finished')
  258. return # Finish the decoding process
  259. # If this point is reached it means no decoder was available
  260. code=asr_common_pb2.DecodeResponse.Status.Code.ERR_NO_RECO_AVAILABLE
  261. raise DecoderException(code=code,
  262. msg=("no recogniser available for ASR"+
  263. " system '%s'")%system.info.id)
  264. # end decode
  265. # end DecoderMng