# test_asr_server.py (2.4 KB)

  1. import logging as Log
  2. import grpc,argparse,time,sys,threading
  3. import mllp_grpc.asr_pb2_grpc as asr_pb2_grpc
  4. import mllp_grpc.asr_pb2 as asr_pb2
  5. import mllp_grpc.asr_common_pb2 as asr_common_pb2
  6. import google.protobuf.empty_pb2 as empty_pb2
  7. from pydub import AudioSegment
  8. from math import ceil
  9. def get_data(args,sys_id):
  10. Log.info("Loading '%s'"%args.wav)
  11. sound= AudioSegment.from_wav(args.wav)
  12. if (sound.frame_rate != 16000):
  13. sound= sound.set_frame_rate(16000)
  14. nbuffers= ceil(len(sound)/args.bsize)
  15. Log.info(('c: %d r: %d rate: %d len: %d nbuf:'+
  16. ' %d bsize: %d')%(sound.channels,
  17. sound.sample_width,
  18. sound.frame_rate,
  19. len(sound),
  20. nbuffers,
  21. args.bsize))
  22. sended= 0
  23. Log.info('Sending data')
  24. yield asr_pb2.DecodeRequest(system_id=sys_id)
  25. for n in range(0,nbuffers):
  26. buf= sound[n*args.bsize:(n+1)*args.bsize]
  27. yield asr_pb2.DecodeRequest(data=buf.raw_data)
  28. sended+= args.bsize
  29. if sended>=1000:
  30. sended-= 1000
  31. time.sleep(1)
  32. Log.info('Transmission completed')
  33. def parse_cmdline():
  34. p= argparse.ArgumentParser(description='ASR GRPC client')
  35. p.add_argument('host',type=str,help='server host')
  36. p.add_argument('port',type=int,help='server port')
  37. p.add_argument('wav',type=str,help='input 1 channel wav at 16KHz')
  38. p.add_argument('--bsize',type=int,default=250,
  39. help='buffer size')
  40. args= p.parse_args()
  41. return args
  42. def run(args):
  43. addr= '%s:%d'%(args.host,args.port)
  44. with grpc.insecure_channel(addr) as channel:
  45. stub= asr_pb2_grpc.ASRStub(channel)
  46. for info in stub.GetSystemsInfo(empty_pb2.Empty()):
  47. print(info)
  48. print_host= True
  49. for o in stub.Decode(get_data(args,1)):
  50. if print_host:
  51. print_host= False
  52. print(o.host_info)
  53. if o.status.code!=asr_common_pb2.DecodeResponse.Status.Code.OK:
  54. print(o)
  55. elif o.hyp_novar!='':
  56. sys.stdout.write(o.hyp_novar+' ')
  57. sys.stdout.flush()
  58. sys.stdout.write('\n')
  59. sys.stdout.flush()
  60. if __name__=='__main__':
  61. Log.basicConfig(level=Log.INFO)
  62. args= parse_cmdline()
  63. run(args)