12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- import logging as Log
- import grpc,argparse,time,sys,threading
- import mllp_grpc.asr_system_pb2_grpc as asr_system_pb2_grpc
- import mllp_grpc.asr_system_pb2 as asr_system_pb2
- import mllp_grpc.asr_common_pb2 as asr_common_pb2
- import google.protobuf.empty_pb2 as empty_pb2
- from pydub import AudioSegment
- from math import ceil
class LazzyIterator:
    """Gate an iterable so that iteration blocks until :meth:`ready` is called.

    Calling the instance returns a generator over the wrapped iterable. The
    generator does not produce anything until ``ready()`` has been invoked
    (from any thread). In ``run`` the generator is handed to ``stub.Decode``
    as the request stream, and ``ready()`` is called once the server reports
    a READY status. The consumer is presumably a gRPC-internal thread, since
    ``ready()`` is called from the main thread while iteration is pending.
    """

    def __init__(self, i):
        """Wrap the iterable ``i``; the gate starts closed."""
        self._i = i
        self._cv = threading.Condition()
        self._ready = False

    def __call__(self):
        """Return a generator over the wrapped iterable.

        Because this is a generator function, the wait for :meth:`ready`
        happens on the first ``next()``, not when the instance is called.
        """
        with self._cv:
            # wait_for re-checks the flag after every wake-up, so a spurious
            # wake-up or an unrelated notify cannot open the gate early.
            self._cv.wait_for(lambda: self._ready)
        # Iterate outside the lock: holding it while suspended at a yield
        # would make a later ready() call block.
        yield from self._i

    def ready(self):
        """Open the gate and wake every generator waiting in ``__call__``."""
        with self._cv:
            self._ready = True
            self._cv.notify_all()

# end LazzyIterator

def get_data(args):
    """Yield the input WAV as ``DataPackage`` chunks, paced to real time.

    Args:
        args: parsed command line. ``args.wav`` is the input WAV path and
            ``args.bsize`` the chunk length in milliseconds (pydub slices
            ``AudioSegment`` objects by milliseconds).

    Yields:
        asr_common_pb2.DataPackage: raw sample bytes of one chunk.

    Raises:
        ValueError: on the first iteration, if ``args.bsize`` is not positive.
    """
    if args.bsize <= 0:
        raise ValueError('bsize must be a positive number of milliseconds')

    Log.info("Loading '%s'", args.wav)
    sound = AudioSegment.from_wav(args.wav)

    # Normalise to 16 kHz mono, the input format the CLI help documents.
    # NOTE(review): sample width is forwarded unchanged; confirm the
    # server's expected PCM sample format.
    if sound.frame_rate != 16000:
        sound = sound.set_frame_rate(16000)
    if sound.channels != 1:
        sound = sound.set_channels(1)

    nbuffers = ceil(len(sound) / args.bsize)
    Log.info('c: %d r: %d rate: %d len: %d nbuf: %d bsize: %d',
             sound.channels, sound.sample_width, sound.frame_rate,
             len(sound), nbuffers, args.bsize)

    sent_ms = 0
    Log.info('Sending data')
    for n in range(nbuffers):
        buf = sound[n * args.bsize:(n + 1) * args.bsize]
        yield asr_common_pb2.DataPackage(data=buf.raw_data)
        # Sleep one second per full second of audio sent so the upload runs
        # at real-time speed. divmod handles chunks longer than 1000 ms,
        # which a single 1 s sleep per chunk would under-throttle.
        sent_ms += len(buf)
        seconds, sent_ms = divmod(sent_ms, 1000)
        if seconds:
            time.sleep(seconds)
    Log.info('Transmission completed')

def parse_cmdline(argv=None):
    """Parse the ASR client's command line.

    Args:
        argv: argument list to parse; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with ``host`` (str), ``port`` (int), ``wav`` (str)
        and ``bsize`` (positive int, milliseconds of audio per packet).
    """
    def positive_int(text):
        # A zero or negative bsize breaks chunking in get_data, so reject it
        # here with a normal argparse usage error.
        value = int(text)
        if value <= 0:
            raise argparse.ArgumentTypeError(
                'must be a positive integer, got %r' % text)
        return value

    p = argparse.ArgumentParser(description='ASR GRPC client')
    p.add_argument('host', type=str, help='server host')
    p.add_argument('port', type=int, help='server port')
    p.add_argument('wav', type=str, help='input 1 channel wav at 16KHz')
    p.add_argument('--bsize', type=positive_int, default=250,
                   help='buffer size in milliseconds of audio per packet')
    return p.parse_args(argv)


def run(args):
    """Print the ASR server's system info, then stream the WAV and print text.

    The audio packets from get_data(args) are held behind a LazzyIterator
    until a response carries the READY status code. After that, every
    non-empty ``hyp_novar`` is written to stdout followed by a space, and a
    response with ``eos`` set ends the current output line.
    """
    address = '%s:%d' % (args.host, args.port)
    with grpc.insecure_channel(address) as channel:
        stub = asr_system_pb2_grpc.ASRSystemStub(channel)
        print(stub.GetSystemInfo(empty_pb2.Empty()))

        gate = LazzyIterator(get_data(args))
        out = sys.stdout
        for response in stub.Decode(gate()):
            if response.status.code == asr_common_pb2.DecodeResponse.Status.READY:
                # Server is ready to receive: let the audio start flowing.
                gate.ready()
                continue
            text = response.hyp_novar
            if text != '':
                out.write(text + ' ')
                out.flush()
            if response.eos:
                out.write('\n')
                out.flush()
        # NOTE(review): source indentation was lost; the separator is assumed
        # to be printed once, after the response stream ends.
        print('------------------------------------------------')

if __name__ == '__main__':
    # Script entry point: enable INFO-level logging for the progress messages,
    # parse the command line, then run the client.
    Log.basicConfig(level=Log.INFO)
    args = parse_cmdline()
    run(args)
|