test_asr_system.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import logging as Log
  2. import grpc,argparse,time,sys,threading
  3. import mllp_grpc.asr_system_pb2_grpc as asr_system_pb2_grpc
  4. import mllp_grpc.asr_system_pb2 as asr_system_pb2
  5. import mllp_grpc.asr_common_pb2 as asr_common_pb2
  6. import google.protobuf.empty_pb2 as empty_pb2
  7. from pydub import AudioSegment
  8. from math import ceil
  9. class LazzyIterator:
  10. def __init__(self,i):
  11. self._i= i
  12. self._cv= threading.Condition()
  13. self._ready= False
  14. def __call__(self):
  15. with self._cv:
  16. if not self._ready:
  17. self._cv.wait()
  18. for v in self._i:
  19. yield v
  20. def ready(self):
  21. with self._cv:
  22. self._ready= True
  23. self._cv.notify()
  24. # end LazzyIterator
  25. def get_data(args):
  26. Log.info("Loading '%s'"%args.wav)
  27. sound= AudioSegment.from_wav(args.wav)
  28. if (sound.frame_rate != 16000):
  29. sound= sound.set_frame_rate(16000)
  30. nbuffers= ceil(len(sound)/args.bsize)
  31. Log.info(('c: %d r: %d rate: %d len: %d nbuf:'+
  32. ' %d bsize: %d')%(sound.channels,
  33. sound.sample_width,
  34. sound.frame_rate,
  35. len(sound),
  36. nbuffers,
  37. args.bsize))
  38. sended= 0
  39. Log.info('Sending data')
  40. for n in range(0,nbuffers):
  41. buf= sound[n*args.bsize:(n+1)*args.bsize]
  42. yield asr_common_pb2.DataPackage(data=buf.raw_data)
  43. sended+= args.bsize
  44. if sended>=1000:
  45. sended-= 1000
  46. time.sleep(1)
  47. Log.info('Transmission completed')
  48. def parse_cmdline():
  49. p= argparse.ArgumentParser(description='ASR GRPC client')
  50. p.add_argument('host',type=str,help='server host')
  51. p.add_argument('port',type=int,help='server port')
  52. p.add_argument('wav',type=str,help='input 1 channel wav at 16KHz')
  53. p.add_argument('--bsize',type=int,default=250,
  54. help='buffer size')
  55. args= p.parse_args()
  56. return args
  57. def run(args):
  58. addr= '%s:%d'%(args.host,args.port)
  59. with grpc.insecure_channel(addr) as channel:
  60. stub= asr_system_pb2_grpc.ASRSystemStub(channel)
  61. info= stub.GetSystemInfo(empty_pb2.Empty())
  62. print(info)
  63. data= LazzyIterator(get_data(args))
  64. for o in stub.Decode(data()):
  65. if o.status.code==asr_common_pb2.DecodeResponse.Status.READY:
  66. data.ready()
  67. continue
  68. if o.hyp_novar!='':
  69. sys.stdout.write(o.hyp_novar+' ')
  70. sys.stdout.flush()
  71. if o.eos:
  72. sys.stdout.write('\n')
  73. sys.stdout.flush()
  74. print('------------------------------------------------')
  75. if __name__=='__main__':
  76. Log.basicConfig(level=Log.INFO)
  77. args= parse_cmdline()
  78. run(args)