Ver código fonte

Code to test ASR/SMT server/systems and foo audio/text

jjorgeDSIC 1 mês atrás
pai
commit
88e0701a2e

BIN
code/audio.obama.wav


+ 116 - 0
code/server_lib/asr_servicer.py

@@ -0,0 +1,116 @@
+__all__= ['ASRServicer','init']
+
+import grpc,threading,traceback
+import logging as Log
+
+import mllp_grpc.asr_pb2_grpc as asr_pb2_grpc
+import mllp_grpc.asr_pb2 as asr_pb2
+import mllp_grpc.asr_common_pb2 as asr_common_pb2
+import google.protobuf.empty_pb2 as empty_pb2
+import server_lib.decoder_multi as decoder_multi
+from server_lib.decoder_mng import DecoderException
+
+def ii_(func,txt,context):
+    Log.info('[%s] (%s) %s'%(context.peer(),func,txt))
+    
+class ASRServicer(asr_pb2_grpc.ASRServicer):
+
+    def __init__(self,mng):
+        self._mng= mng
+
+    def AddSystem(self,request,context):
+        def ii(txt): ii_('AddSystem',txt,context)
+        try:
+            self._mng.add(request.host,request.port,ii)
+            ii('added %s:%d'%(request.host,request.port))
+            code= asr_pb2.AddSystemResponse.Code.OK
+            return asr_pb2.AddSystemResponse(code=code)
+        except Exception as e:
+            msg= 'error adding system %s:%d: %s'%(request.host,
+                                                  request.port,
+                                                  str(e))
+            ii(msg)
+            code= asr_pb2.AddSystemResponse.Code.OK
+            return asr_pb2.AddSystemResponse(code=code,
+                                             details=msg)
+
+    def GetSystemInfoById(self,request,context):
+        def ii(txt): ii_('GetSystemInfoById',txt,context)
+        ii("retrieving ASR system information for %s"%request.system_id)
+        ok,info= self._mng.get_system_info(request.system_id,ii)
+        if ok:
+            return asr_pb2.GetSystemInfoByIdResponse(
+                code=asr_pb2.GetSystemInfoByIdResponse.Code.OK,
+                info=info.info,
+                num_decoders=info.num_decoders,
+                num_decoders_available=info.num_decoders_available,
+                id=info.id)
+        else:
+            return asr_pb2.GetSystemInfoByIdResponse(
+                code=asr_pb2.GetSystemInfoByIdResponse.Code.ERR)
+        
+    def GetSystemsInfo(self,request,context):
+        def ii(txt): ii_('GetSystemsInfo',txt,context)
+        ii('retrieving ASR systems information')
+        for v in self._mng.get_systems_info(ii):
+            yield asr_pb2.GetSystemsInfoResponse(
+                info=v.info,
+                num_decoders=v.num_decoders,
+                num_decoders_available=v.num_decoders_available,
+                id=v.id)
+
+    def GetHostsInfo(self,request,context):
+        def ii(txt): ii_('GetHostsInfo',txt,context)
+        ii('retrieving hosts information')
+        for v in self._mng.get_hosts_info(ii):
+            yield asr_pb2.GetHostsInfoResponse(
+                info=v.info,
+                num_decoders=v.num_decoders,
+                num_decoders_available=v.num_decoders_available,
+                host=v.host,
+                port=v.port)
+            
+    def Decode(self,request_iterator,context):
+        def ii(txt): ii_('Decode',txt,context)
+        try:
+            for o in self._mng.decode(request_iterator,ii):
+                yield o
+        except DecoderException as e:
+            ii(e.msg)
+            status= asr_common_pb2.DecodeResponse.Status(code=e.code,
+                                                         details=e.msg)
+            yield asr_common_pb2.DecodeResponse(status=status)
+        except BaseException as e:
+            traceback.print_exc()
+            ii(str(e))
+            code= asr_common_pb2.DecodeResponse.Status.Code.ERR_RECO
+            status= asr_common_pb2.DecodeResponse.Status(code=code,
+                                                         details=str(e))
+            yield asr_common_pb2.DecodeResponse(status=status)
+        ii('Done!')
+
+    def DecodeMulti(self,request_iterator,context):
+        def ii(txt): ii_('DecodeMulti',txt,context)
+        try:
+            for o in decoder_multi.decode(self._mng,request_iterator,ii):
+                yield o
+        except DecoderException as e:
+            ii(e.msg)
+            status= asr_common_pb2.DecodeResponse.Status(code=e.code,
+                                                         details=e.msg)
+            res= asr_common_pb2.DecodeResponse(status=status)
+            yield asr_pb2.DecodeMultiResponse(res=res)
+        except BaseException as e:
+            traceback.print_exc()
+            ii(str(e))
+            code= asr_common_pb2.DecodeResponse.Status.Code.ERR_RECO
+            status= asr_common_pb2.DecodeResponse.Status(code=code,
+                                                         details=str(e))
+            res= asr_common_pb2.DecodeResponse(status=status)
+            yield asr_pb2.DecodeMultiResponse(res=res)
+        ii('Done!')
+        
+def init(server,mng):
+    asr_pb2_grpc.add_ASRServicer_to_server(
+        ASRServicer(mng),server)
+    

+ 423 - 0
code/server_lib/asr_system.py

@@ -0,0 +1,423 @@
+__all__= ['load']
+
+import sys,threading,collections,time
+import TLK
+import lmodel,oepmodel,segmodel
+
+import logging as Log
+
+Hyp= collections.namedtuple('Hyp',['var','novar','err',
+                                   'score','nframes',
+                                   'eos'])
+
+class Recogniser:
+
+    def __init__(self,models,params,
+#                 uppercase_tbl,uppercase_tbl_bi,
+                 sil_word,sil_sym,fea_freq,mng):
+        args= {
+            'ifreq'  : 16000,
+            'bsize'  : 400,
+            'amodel' : models['amodel'],
+            'lm'     : models['lm'],
+            'params' : params,
+            'output' : self,
+        }
+
+        self._fea_freq= fea_freq
+        if sil_word!=None:
+            args['sil_word']= sil_word
+            if sil_sym!=None:
+                args['sil_sym']= sil_sym
+        if models['dnn']!=None: args['dnn']= models['dnn']
+        if models['dlm']!=None: args['dlm']= models['dlm']
+        if models['mustd']!=None: args['mustd']= models['mustd']
+        if models['oep-factory']!=None:
+            oep= models['oep-factory'](models['oep-model'])
+            args['oep']= oep
+        self.reco= TLK.ORecogniser(**args)
+        if models['seg-factory']!=None:
+            self.seg= models['seg-factory'](models['seg-model'])
+        else: self.seg= None
+        self._mng= mng
+        self._cv= threading.Condition()
+
+    def reset(self):
+        self._err= False
+        self._hyps= []
+        self._prev= None
+        if self.seg is not None:
+            self.seg.reset()
+        # Per a gestionar tokens insertats el que vaig a fer es tindre
+        # una llista de llistes. En la primera llista estaran els
+        # tokens del segment que s'està processant actualment i en
+        # l'última els tokens de l'últim segment. Cada token serà una
+        # tupla paraula i número de frames on es va insertar.
+        self._itokens= [[]]
+        self._inframes= 0.0 # Numero de frames del segment actual
+                            # insertats en el reconeixedor
+
+    def __register_itoken_nosec(self,token):
+        token= token.strip()
+        if token=='': return
+        nframes= int(self._inframes) if self._inframes>0.0 else None
+        self._itokens[-1].append((token,nframes))
+
+    def __register_itoken(self,token):
+        try:
+            self.__register_itoken_nosec(token)
+        except:
+            pass
+    
+    def __register_itoken_eos(self):
+        if self._inframes>0.0:
+            self._itokens.append([])
+            self._inframes= 0.0
+
+    def __merge_itokens_nosec(self,code,hyp):
+        # Comprovacions inicials i eliminació
+        itokens= self._itokens[0]
+        if code==TLK.OUT_RES: self._itokens.pop(0)
+        if itokens==[]: return hyp
+        # Inserta
+        hyp,pos_var= hyp[0],hyp[1]
+        pref,novar,var= [],list(hyp[:pos_var]),list(hyp[pos_var:])
+        # -> Inserta inici segment
+        while len(itokens)>0 and itokens[0][1] is None:
+            tok= itokens.pop(0)
+            pref.append((tok[0],0,0))
+        # --> Processa novar
+        if len(itokens)>0 and len(novar)>0:
+            i,new_novar= 0,[]
+            while len(itokens)>0 and i<len(novar):
+                w,b,e= novar[i]
+                epos= b+e
+                while len(itokens)>0 and itokens[0][1]<epos:
+                    tok= itokens.pop(0)
+                    new_novar.append((tok[0],b,0))
+                new_novar.append((w,b,e))
+                i+= 1
+            while i<len(novar):
+                new_novar.append(novar[i])
+                i+= 1
+        else: new_novar= novar
+        # --> Processa var
+        if len(itokens)>0 and len(var)>0:
+            i,j,new_var= 0,0,[]
+            while len(itokens)>0 and i<len(var):
+                w,b,e= var[i]
+                epos= b+e
+                while j<len(itokens) and itokens[j][1]<epos:
+                    new_var.append((itokens[j][0],b,0))
+                    j+= 1
+                new_var.append((w,b,e))
+                i+= 1
+            while i<len(var):
+                new_var.append(var[i])
+                i+= 1
+        else: new_var= var
+        # --> Afegeix pendents si code==TLK.OUT_RES
+        if code==TLK.OUT_RES and len(itokens)>0:
+            assert var==[]
+            pos= new_novar[-1][1]+new_novar[-1][2] if len(new_novar)>0 else -1
+            for tok in itokens:
+                tmp_pos= pos if pos!=-1 else tok[1]
+                new_novar.append((tok[0],tmp_pos,0))
+        return pref+new_novar+new_var,len(pref)+len(new_novar)
+    # end __merge_itokens_nosec
+
+    def __merge_itokens(self,code,hyp):
+        try:
+            return self.__merge_itokens_nosec(code,hyp)
+        except:
+            return hyp
+    
+    # None denotes end of segment
+    def feed(self,data=None):
+        if data==None:
+            self.reco.feed()
+        elif not self._err:
+            if type(data)==str:
+                self.__register_itoken(data)
+            elif len(data)==0:
+                self.reco.split()
+                self.__register_itoken_eos()
+            else:
+                self._inframes+= (len(data)/32000)*self._fea_freq
+                self.reco.feed(data)
+
+    def __write_hyp(self,novar,var,score,nframes,err,eos):
+        h= Hyp(novar=novar,var=var,err=err,
+               score=score,nframes=nframes,
+               eos=eos)
+        with self._cv:
+            self._hyps.append(h)
+            self._cv.notify_all()
+        
+    def process_out(self,code,hyp,stats):
+        #def totxt(rec,last_word):
+        #    aux= []
+        #    for x in rec:
+        #        w= x[0].lower()
+        #        aux.append(toupper(w,last_word))
+        #        last_word= w
+        #    ret= ' '.join(aux)
+        #    ret= ret.replace('[hesitation]','').replace('<unk>','')
+        #    return ret
+        def totxt(rec):
+            ret= ' '.join([x[0] for x in rec])
+            ret= ret.replace('[hesitation]','').replace('<unk>','')
+            return ret
+        hyp,pos_var= hyp[0],hyp[1]
+        novar= hyp[:pos_var]
+        #new_last_word= novar[-1][0] if len(novar)>0 else self.last_word
+        #novar= totxt(novar,self.last_word)
+        #var= totxt(hyp[pos_var:],new_last_word)
+        #self.last_word= new_last_word
+        novar= totxt(novar)
+        var= totxt(hyp[pos_var:])
+        if novar=='' and var==self._prev and code==TLK.OUT_HYP: return
+        self._prev= var
+        score,nframes= stats
+        self.__write_hyp(novar,var,score,nframes,False,
+                         code==TLK.OUT_RES)
+        
+    # Output method
+    def write(self,code,hyp,stats):
+        hyp= self.__merge_itokens(code,hyp)
+        if code==TLK.OUT_ERR:
+            txt= self._prev if self._prev!=None else ''
+            self.__write_hyp(txt,'',0,0,False,False)
+            self._err= True
+            return
+        if code==TLK.OUT_END:
+            self.__write_hyp(None,None,0,0,self._err,False)
+            return
+        if self.seg is not None:
+            hyp= self.seg(hyp)
+            if code==TLK.OUT_RES:
+                hyp0,pos_var0= hyp
+                hyp1= self.seg.eos()
+                if hyp1 is not None:
+                    hyp1,pos_var1= hyp1
+                    hyp= hyp0[:pos_var0]+hyp1[:pos_var1]
+                    hyp= hyp,len(hyp)
+                self.seg.reset()
+        self.process_out(code,hyp,stats)
+
+    @property
+    def output(self):
+        end= False
+        while not end:
+            with self._cv:
+                self._cv.wait_for(lambda: len(self._hyps)>0)
+                ret= self._hyps.pop(0)
+                end= ret.novar is None
+            yield ret
+    
+# end Recogniser
+
+class RecogniserManager:
+
+    def __init__(self,name,tag,date,lang,
+                 models,params,nreco,
+                 sil_word,sil_sym,fea_freq):
+        self.v= []
+        self.vv= [] # Reference copy
+        #self.cmllr= params.cmllr_enabled
+        for n in range(0,nreco):
+            reco= Recogniser(models,params,
+                             sil_word,sil_sym,
+                             fea_freq,self)
+            self.v.append(reco)
+            self.vv.append(reco)
+        self.name= name
+        self.tag= tag
+        self.date= date
+        self.lang= lang
+        self._lock= threading.Lock()
+        self._models= models
+        self._enabled= True
+
+    def __len__(self):
+        return len(self.vv)
+    
+    def get_reco(self):
+        with self._lock:
+            if self.v==[]: return None
+            if not self._enabled: return None
+            ret= self.v.pop()
+            ret.reset()
+            #ret.reset_cmllr()
+            Log.info(('Recogniser taken from %s '+
+                      '(available: %d of %d)')%(self.name,
+                                                len(self.v),
+                                                len(self.vv)))
+            return ret
+    
+    def append(self,reco):
+        with self._lock:
+            self.v.append(reco)
+            Log.info(('Recogniser from %s realeased'+
+                      '(available: %d of %d)')%(self.name,
+                                                len(self.v),
+                                                len(self.vv)))
+    
+    def set_enabled(self,value):
+        with self._lock:
+            self._enabled= value
+    
+    @property
+    def num_recos_available(self):
+        with self._lock:
+            return len(self.v)
+
+    @property
+    def enabled(self):
+        with self._lock:
+            return self._enabled
+
+# end RecogniserManager
+
+def prepare_params(cfg):
+    params= TLK.OParameters()
+    aux= cfg.get('hp')
+    if aux!=None: params.hp= int(aux)
+    aux= cfg.get('hp_min')
+    if aux!=None: params.hp_min= int(aux)
+    aux= cfg.get('wep')
+    if aux!=None: params.wep= float(aux)
+    aux= cfg.get('beam')
+    if aux!=None: params.beam= float(aux)
+    aux= cfg.get('dynormthr')
+    if aux!=None: params.dynormthr= int(aux)
+    aux= cfg.get('meannorm')
+    if aux!=None: params.meannorm= bool(aux)
+    aux= cfg.get('cmllr')
+    if aux!=None: params.cmllr_enabled= bool(aux)
+    aux= cfg.get('cmllr_nframes_step0')
+    if aux!=None: params.cmllr_nframes_step0= int(aux)
+    aux= cfg.get('cmllr_nframes')
+    if aux!=None: params.cmllr_nframes= int(aux)
+    aux= cfg.get('cmllr_niters')
+    if aux!=None: params.cmllr_niters= int(aux)
+    aux= cfg.get('numceps')
+    if aux!=None: params.numceps= int(aux)
+    aux= cfg.get('numchans')
+    if aux!=None: params.numchans= int(aux)
+    aux= cfg.get('accwindow')
+    if aux!=None: params.accwindow= int(aux)
+    aux= cfg.get('deltawindow')
+    if aux!=None: params.deltawindow= int(aux)
+    aux= cfg.get('hp_lm')
+    if aux!=None: params.hp_lm= int(aux)
+    aux= cfg.get('order')
+    if aux!=None: params.order= int(aux)
+    aux= cfg.get('ftype')
+    if aux!=None: params.ftype= str(aux)
+    aux= cfg.get('mustdnorm')
+    if aux!=None: params.mustdnorm= bool(aux)
+    aux= cfg.get('gsf')
+    gsf= float(aux) if aux!=None else 1.0
+    params.gsf= gsf
+    aux= cfg.get('wip')
+    wip= float(aux) if aux!=None else 0.0
+    params.wip= wip
+    params.amla= cfg.get('amla-fs')!=None
+    params.sil= cfg.get('sil_sym','SP')
+    
+    # IMPORTANT !!!!! Deshabilita l'heurístic de segmentació
+    params.sil_length= 20000
+    params.sil_thr= 0.0
+
+    return params
+# end prepare_params
+
+def load_models(cfg,name,params,sil_word):
+    Log.info("Loading '%s' models..."%name)
+    mustd= cfg.get('mustd')
+    dnn= cfg.get('dnn')
+    if params.cmllr_enabled:
+        sys.exit('CMLLR NOT SUPPORTED !!!!')
+        target= cfg.get('target')
+        dnn_cmllr= cfg.get('dnn_cmllr')
+    else:
+        target= dnn_cmllr= None
+        lexicon= cfg.get('lexicon')
+        
+    amodel= TLK.AModel(cfg['amodel'])
+    if lexicon!=None:
+        lex= TLK.Lexicon(lexicon, syms=amodel.syms)
+        lm= TLK.LM()
+        lm.load(cfg['lm'],lexicon=lex)
+    else:
+        try:
+            lm= TLK.SearchGraph(cfg['lm'], syms=amodel.syms)
+        except:
+            lm= TLK.StaticLookaheadTables(cfg['lm'], syms=amodel.syms)
+        dlm= cfg.get('dlm')
+    if dlm!=None:
+        dlm= lmodel.load_model(dlm)(lm.words)
+    oep= cfg.get('oep')
+    if oep!=None:
+        step= cfg.get('oep-step')
+        if step==None:
+            sys.exit('oep-step not defined')
+        step= int(step)
+        lookahead= cfg.get('oep-lookahead')
+        if lookahead==None:
+            sys.exit('oep-lookahead not defined')
+        lookahead= int(lookahead)
+        priors_fn= cfg.get('oep-priors')
+        amla_fs= cfg.get('amla-fs')
+        tmp= oepmodel.load_model(oep,step,lookahead,
+                                 priors_fn,amla_fs)
+        oep_factory,oep_model= tmp
+    else: oep_factory= oep_model= None
+    seg= cfg.get('seg')
+    if seg!=None:
+        tmp= segmodel.load_model(seg,sil_word=sil_word)
+        seg_factory,seg_model= tmp
+    else: seg_factory= seg_model= None
+    aux= cfg.get('gsf')
+    gsf= float(aux) if aux!=None else 1.0
+    aux= cfg.get('wip')
+    wip= float(aux) if aux!=None else 0.0
+    models= {}
+    models["amodel"]=amodel
+    models["lm"]=lm
+    models["dnn"]=dnn
+    models["gsf"]=gsf
+    models["wip"]=wip
+    models['dlm']= dlm
+    models['oep-factory']= oep_factory
+    models['oep-model']= oep_model
+    models['mustd']= mustd
+    models['seg-factory']= seg_factory
+    models['seg-model']= seg_model
+    return models
+
+# end load_models
+    
+def load_system(conf):
+    sil_word= conf.get('sil_word')
+    sil_sym= conf.get('sil_sym')
+    fea_freq= conf.get('fea_freq',100)
+    params= prepare_params(conf)
+    name= conf['id']
+    tag= conf.get('tag',name)
+    date= tuple(conf.get('date',[1,1,1971]))
+    models= load_models(conf,tag,params,sil_word)
+    Log.info("Creating '%s' recognisers..."%name)
+    return RecogniserManager(name,tag,date,conf['lang'],
+                             models,params,conf['nreco'],
+                             sil_word,sil_sym,fea_freq)
+#                             uppercase_tbl,uppercase_tbl_bi,
+    # Uppercase-table (Deshabilitat de moment)
+    #aux= cfg.get('uppercase-table')
+    #uppercase_tbl= load_uppercase_table(aux) if aux!=None else None
+    #aux= cfg.get('uppercase-table-bi')
+    #uppercase_tbl_bi= load_uppercase_table_bi(aux) if aux!=None else None
+    
+# end create_recognisers
+    

+ 108 - 0
code/server_lib/asr_system_servicer.py

@@ -0,0 +1,108 @@
+__all__= ['ASRSystemServicer','init']
+
+import grpc,threading,traceback
+import logging as Log
+
+import mllp_grpc.asr_system_pb2_grpc as asr_system_pb2_grpc
+import mllp_grpc.asr_system_pb2 as asr_system_pb2
+import mllp_grpc.asr_common_pb2 as asr_common_pb2
+import google.protobuf.empty_pb2 as empty_pb2
+
+def ii_(func,txt,context):
+    Log.info('[%s] (%s) %s'%(context.peer(),func,txt))
+
+class ASRSystemServicer(asr_system_pb2_grpc.ASRSystemServicer):
+
+    def __init__(self,system):
+        self._sys= system
+
+    def GetSystemInfo(self,request,context):
+        def ii(txt): ii_('GetSystemInfo',txt,context)
+        ii("Retrieving info from ASR system")
+
+        langs=[asr_common_pb2.SystemInfo.Lang(code=x[0],
+                                              text=x[1])
+               for x in self._sys.lang]
+        date= asr_common_pb2.SystemInfo.Date(day=self._sys.date[0],
+                                             month=self._sys.date[1],
+                                             year=self._sys.date[2])
+        info= asr_common_pb2.SystemInfo(id=self._sys.name,
+                                        langs=langs,
+                                        tag=self._sys.tag,
+                                        date=date)
+        num_recos= self._sys.num_recos_available
+        ret= asr_system_pb2.GetSystemInfoResponse(info=info,
+                                                  num_decoders=len(self._sys),
+                                                  num_decoders_available=num_recos,
+                                                  enabled=self._sys.enabled)
+        return ret
+        
+    def Decode(self,request_iterator,context):
+        def ii(txt): ii_('Decode',txt,context)
+        def ee(code,txt):
+            status= asr_common_pb2.DecodeResponse.Status(code=code,
+                                                         details=txt)
+            return asr_common_pb2.DecodeResponse(status=status)
+        # Check system selected
+        # Check recogniser available
+        reco= self._sys.get_reco()
+        if reco is None:
+            return ee(asr_common_pb2.DecodeResponse.Status.Code.ERR_NO_RECO_AVAILABLE,
+                      "no recogniser available for ASR system '%s'"%self._sys.name)
+        code= asr_common_pb2.DecodeResponse.Status.Code.READY
+        status= asr_common_pb2.DecodeResponse.Status(code=code)
+        yield asr_common_pb2.DecodeResponse(status=status)
+        ii('starting decoding with system: %s'%self._sys.name)
+        # Create a thread to feed data
+        def feed_data():
+            try:
+                for d in request_iterator:
+                    if d.HasField('token'):
+                        reco.feed(d.token)
+                    else:
+                        reco.feed(d.data)
+            except Exception as e:
+                traceback.print_exc()
+                ii('[feed data] client closed connection, recognition stoped')
+            reco.feed()
+        t= threading.Thread(target=feed_data)
+        t.start()
+        # Process output
+        try:
+            for o in reco.output:
+                if o.novar is None: # EOS or ERR
+                    if o.err==True: # Err
+                        yield ee(asr_common_pb2.DecodeResponse.Status.Code.ERR_RECO,
+                                 "an unexpected error ocurred during recognition")
+                    else: pass # nothing on EOS
+                else:
+                    val= asr_common_pb2.DecodeResponse.Status.Code.OK
+                    status= asr_common_pb2.DecodeResponse.Status(code=val)
+                    yield asr_common_pb2.DecodeResponse(status=status,
+                                                        hyp_novar=o.novar,
+                                                        hyp_var=o.var,
+                                                        score=o.score,
+                                                        nframes=o.nframes,
+                                                        eos=o.eos)
+        except:
+            traceback.print_exc()
+            ii('[read output] client closed connection, recognition stoped')
+            for o in reco.output:
+                pass
+        # Wait thread (should be stopped at this point) and return reco
+        t.join()
+        self._sys.append(reco)
+        ii('decoding finished')
+    # end Decode
+
+    def SetEnabled(self,request,context):
+        def ii(txt): ii_('SetEnabled',txt,context)
+        ii("system %s"%('enabled' if request.value else 'disabled'))
+        self._sys.set_enabled(request.value)
+        ret= empty_pb2.Empty()
+        return ret
+        
+def init(server,system):
+    asr_system_pb2_grpc.add_ASRSystemServicer_to_server(
+        ASRSystemServicer(system),server)
+    

+ 291 - 0
code/server_lib/decoder_mng.py

@@ -0,0 +1,291 @@
+__all__= ['DecoderMng','DecoderException']
+
+import threading,grpc,collections,traceback
+
+import logging as Log
+
+import mllp_grpc.asr_system_pb2_grpc as asr_system_pb2_grpc
+import mllp_grpc.asr_system_pb2 as asr_system_pb2
+import mllp_grpc.asr_common_pb2 as asr_common_pb2
+import google.protobuf.empty_pb2 as empty_pb2
+
+System= collections.namedtuple('System',['info','hosts'])
+
+SystemInfo= collections.namedtuple('SystemInfo',
+                                   ['id','info','num_decoders',
+                                    'num_decoders_available'])
+HostInfo= collections.namedtuple('HostInfo',
+                                 ['info','num_decoders',
+                                  'num_decoders_available',
+                                  'port','host'])
+
+def II(txt): Log.info('(DecoderMng) %s'%txt)
+
+class DecoderException(Exception):
+
+    def __init__(self,code,msg):
+        self.code= code
+        self.msg= msg
+        
+class LazzyIterator:
+    
+    def __init__(self,i):
+        self._i= i
+        self._cv= threading.Condition()
+        self._ready= False
+        
+    def __call__(self):
+        with self._cv:
+            if not self._ready:
+                self._cv.wait()
+        for v in self._i:
+            yield v
+            
+    def ready(self):
+        with self._cv:
+            self._ready= True
+            self._cv.notify()
+
+# end LazzyIterator
+
+class DecoderMng:
+
+    def __init__(self):
+        self._lock= threading.Lock()
+        self._sys_map= {} # System id -> int
+        self._sys= [] # Llista de sistemes
+
+    # NOTA! Pot generar excepcions.
+    # Torna la info d'un sistema
+    def __get_sys_info(self,host,port,ii=II):
+        try:
+            addr= '%s:%d'%(host,port)
+            with grpc.insecure_channel(addr) as channel:
+                stub= asr_system_pb2_grpc.ASRSystemStub(channel)
+                return stub.GetSystemInfo(empty_pb2.Empty())
+        except:
+            ii('unable to connect to %s:%d'%(host,port))
+            raise Exception('%s:%d is not a valid ASR System server'%(host,port))
+
+    def __remove_host(self,sys_id,host,port,ii=II):
+        with self._lock:
+            ii(('trying to remove host %s:%d'+
+                ' from system %s')%(host,port,
+                                    self._sys[sys_id].info.id))
+            try:
+                del self._sys[sys_id].hosts[(host,port)]
+                ii('host removed')
+            except:
+                pass
+    
+    # Donat un sys_id (un int) torna un iterador sobre parells:
+    # connexió al hosts d'este sistema i info. Si algun host falla al
+    # connectar-se o el seu identificador no correspon amb el que
+    # deuria, el host és ignorat i eliminat de la llista de
+    # hosts. Aquest iterador no genera excepcions.
+    def __iter_sys_stubs(self,sys_id,ii=II):
+        with self._lock:
+            system= self._sys[sys_id]
+            hosts= list(system.hosts.keys())
+        for host,port in hosts:
+            # Prova connexió i comprova id
+            try:
+                addr= '%s:%d'%(host,port)
+                with grpc.insecure_channel(addr) as channel:
+                    stub= asr_system_pb2_grpc.ASRSystemStub(channel)
+                    info= stub.GetSystemInfo(empty_pb2.Empty())
+                    if info.info.id!=system.info.id:
+                        ii(('bad identifier for host %s:%d:'+
+                            ' %s != %s')%(host,port,info.info.id,
+                                          system.info.id))
+                        self.__remove_host(sys_id,host,port,ii)
+                    elif info.info.tag!=system.info.tag:
+                        ii(('bad tag for host %s:%d:'+
+                            ' %s != %s')%(host,port,info.info.tag,
+                                          system.info.tag))
+                        self.__remove_host(sys_id,host,port,ii)
+                    else:
+                        yield stub,info,(host,port)
+            except grpc.RpcError:
+                traceback.print_exc()
+                ii('unable to connect to %s:%d'%(host,port))
+                self.__remove_host(sys_id,host,port,ii)
+            except GeneratorExit:
+                pass
+            except:
+                traceback.print_exc()
+                ii('unknown error')
+    # end __iter_sys_stubs
+
+    def __num_connexions(self,sys_id,ii=II):
+        ret= 0
+        for stub,sinfo,_ in self.__iter_sys_stubs(sys_id,ii):
+            ret+= 1
+        return ret
+
+    # NOTA! Pot generar excepcions.
+    def add(self,host,port,ii=II):
+        info= self.__get_sys_info(host,port,ii)
+        with self._lock:
+            sys_id= self._sys_map.get(info.info.id)
+        numcon= 0 if sys_id is None else self.__num_connexions(sys_id)
+        # Add host
+        with self._lock:
+            # Repetit a propòsit!!!
+            sys_id= self._sys_map.get(info.info.id)
+            if sys_id==None:
+                sys_id= len(self._sys)
+                self._sys_map[info.info.id]= sys_id
+                sys= System(info=info.info,hosts={})
+                self._sys.append(sys)
+            elif self._sys[sys_id].info.tag!=info.info.tag:
+                system= self._sys[sys_id]
+                hosts= list(system.hosts.keys())
+                if len(system.hosts)>0 and numcon>0:
+                    raise RuntimeError(
+                        'tag mismatch %s != %s'%(self._sys[sys_id].info.tag,
+                                                 info.info.tag))
+                else:
+                    self._sys[sys_id]= System(info=info.info,hosts={})
+            key= (host,port)
+            self._sys[sys_id].hosts[key]= True
+    # end add
+
+    def get_system_info(self,system_id,ii=II):
+        # Obté el sys_id associat a sytem_id
+        sys_id,system= -1,...
+        with self._lock:
+            for tmp_sys_id,tmp_system in enumerate(self._sys):
+                if tmp_system.info.id==system_id:
+                    sys_id= tmp_sys_id
+                    system= tmp_system
+                    break
+        if sys_id==-1: return False,...
+        # Obté info
+        num_decoders= 0
+        num_decoders_available= 0
+        for stub,sinfo,_ in self.__iter_sys_stubs(sys_id,ii):
+            if sinfo.enabled:
+                num_decoders+= sinfo.num_decoders
+                num_decoders_available+= sinfo.num_decoders_available
+        if num_decoders>0:
+            return True,SystemInfo(info=system.info,
+                                   num_decoders=num_decoders,
+                                   num_decoders_available=num_decoders_available,
+                                   id=sys_id)
+        else: return False,...
+    # end get_system_info
+    
+    def get_systems_info(self,ii=II):
+        # Llista de systemes
+        with self._lock:
+            nsys= len(self._sys)
+        # Itera sobre els sistemes
+        for sys_id in range(nsys):
+            # Get info
+            with self._lock:
+                system= self._sys[sys_id]
+            # Pregunta a tots els hosts
+            num_decoders= 0
+            num_decoders_available= 0
+            for stub,sinfo,_ in self.__iter_sys_stubs(sys_id,ii):
+                if sinfo.enabled:
+                    num_decoders+= sinfo.num_decoders
+                    num_decoders_available+= sinfo.num_decoders_available
+            # Return
+            if num_decoders>0:
+                yield SystemInfo(info=system.info,
+                                 num_decoders=num_decoders,
+                                 num_decoders_available=num_decoders_available,
+                                 id=sys_id)
+    # end get_systems_info
+
+    def get_hosts_info(self,ii=II):
+        # Llista de systemes
+        with self._lock:
+            nsys= len(self._sys)
+        # Itera sobre els sistemes
+        for sys_id in range(nsys):
+            # Get info
+            with self._lock:
+                system= self._sys[sys_id]
+            # Pregunta a tots els hosts
+            for stub,sinfo,addr in self.__iter_sys_stubs(sys_id,ii):
+                yield HostInfo(info=system.info,
+                               num_decoders=sinfo.num_decoders,
+                               num_decoders_available=sinfo.num_decoders_available,
+                               host=addr[0],
+                               port=addr[1])
+    # end get_hosts_info
+    
+    # NOTA!!! Pot generar excepcions de tipus DecoderException o
+    # Exception.
+    def decode(self,request_iterator,ii):
+        def ee_format(txt):
+            code=asr_common_pb2.DecodeResponse.Status.Code.ERR_WRONG_FORMAT
+            raise DecoderException(code=code,msg=txt)
+        def ee_unk_system(sys_id):
+            code=asr_common_pb2.DecodeResponse.Status.Code.ERR_UNKNOWN_SYSTEM
+            raise DecoderException(code=code,msg='%d is not a valid ASR system'%sys_id)
+        def transform_data(iterator):
+            try:
+                for o in iterator:
+                    if o.HasField('token'):
+                        yield asr_common_pb2.DataPackage(token=o.token)
+                    elif not o.HasField('data'):
+                        ee_format(('a system_id package was provided'+
+                                   ' when a data package was expected'))
+                    else:
+                        yield asr_common_pb2.DataPackage(data=o.data)
+            except grpc.RpcError:
+                ii(('an expected error ocurred while reading input data,'+
+                    ' probably connexion was closed by client'))
+        # Selecciona sistema
+        p= next(request_iterator,None)
+        if p is None: ee_format('no packages were sent for decoding')
+        if not p.HasField('system_id'):
+            ee_format('First package must contain a system identifier')
+        try:
+            with self._lock:
+                system= self._sys[p.system_id]
+        except:
+            ee_unk_system(p.system_id)
+        ii('starting decoding with system: %s'%system.info.id)
+        # Find for the first available system and decode
+        add_host_info= True
+        input_data= LazzyIterator(transform_data(request_iterator))
+        for stub,sinfo,addr in self.__iter_sys_stubs(p.system_id,ii):
+            if sinfo.enabled and sinfo.num_decoders_available>0:
+                ok= False
+                for o in stub.Decode(input_data()):
+                    if o.status.code==asr_common_pb2.DecodeResponse.Status.READY:
+                        ii('connexion stablished with %s:%d'%(addr[0],addr[1]))
+                        input_data.ready()
+                        ok= True
+                        continue
+                    elif o.status.code==asr_common_pb2.DecodeResponse.Status.ERR_NO_RECO_AVAILABLE:
+                        break # Abort decoding
+                    if add_host_info:
+                        tmp= asr_common_pb2.DecodeResponse.HostInfo(host=addr[0],port=addr[1])
+                        new_o= asr_common_pb2.DecodeResponse(status=o.status,
+                                                             hyp_novar=o.hyp_novar,
+                                                             hyp_var=o.hyp_var,
+                                                             eos=o.eos,
+                                                             score=o.score,
+                                                             nframes=o.nframes,
+                                                             host_info=tmp)
+                        add_host_info= False
+                        o= new_o
+                    yield o
+                if ok:
+                    ii('decoding finished')
+                    return # Finish the decoding process
+        # If this point is reached it means no decoder was available
+        code=asr_common_pb2.DecodeResponse.Status.Code.ERR_NO_RECO_AVAILABLE
+        raise DecoderException(code=code,
+                               msg=("no recogniser available for ASR"+
+                                    " system '%s'")%system.info.id)
+    
+    # end decode
+    
+# end DecoderMng

+ 211 - 0
code/server_lib/decoder_multi.py

@@ -0,0 +1,211 @@
+__all__= ['decode']
+
+import concurrent.futures as cf
+import threading
+
+import mllp_grpc.asr_pb2 as asr_pb2
+import mllp_grpc.asr_common_pb2 as asr_common_pb2
+
+from server_lib.decoder_mng import DecoderException
+
+def ee_format(txt):
+    code=asr_common_pb2.DecodeResponse.Status.Code.ERR_WRONG_FORMAT
+    raise DecoderException(code=code,msg=txt)
+
+class Demultiplexor:
+
+    def __init__(self,request_iterator):
+        self._qs= []
+        self.pos2sysid= []
+        # Els primers paquets són els identificadors
+        for o in request_iterator:
+            if not o.HasField('system_id'):
+                break
+            new_q= [o]
+            self._qs.append(new_q)
+            self.pos2sysid.append(o.system_id)
+        if self._qs==[]: ee_format('no system_id were provided')
+        if not o.HasField('data'):
+            ee_format('no data packages were provided')
+        # Insert in all qs
+        for q in self._qs:
+            q.append(o)
+        # Init state
+        self._it= request_iterator
+        self._stop= False
+        self._counter= self.num_streams
+        self._cv= threading.Condition()
+        
+    # end __init__
+
+    def __refill_nonsecure(self):
+        # Dec and check
+        if self._stop: return
+        self._counter-= 1
+        if self._counter>0: return # L'últim replena
+        # Fill
+        o= next(self._it,None)
+        if o is None: self._stop= True
+        else:
+            self._counter= self.num_streams
+            for q in self._qs:
+                q.append(o)
+        self._cv.notify_all()
+
+    def __refill(self):
+        try:
+            self.__refill_nonsecure()
+        except BaseException as e:
+            self._stop= True
+            self._cv.notify_all()
+            raise e
+    
+    # Se li passa el número de cola
+    def __call__(self,q_id):
+        while True:
+            # Intenta obtindre el següent element
+            with self._cv:
+                q= self._qs[q_id]
+                self._cv.wait_for(lambda: len(q)>0 or self._stop)
+                if self._stop and len(q)==0: return # Finalitzem
+                o= q.pop(0)
+                if q==[]: self.__refill()
+            yield o
+    # end __call__
+            
+    @property
+    def num_streams(self):
+        return len(self._qs)
+    
+# end Demultiplexor
+
+
+class SegmentLangId:
+    def __init__(self,num_langs,seg_id=0):
+        NL= num_langs
+        self._seg_id= seg_id
+        self._next= None
+        self._max= [float('-inf')]*NL
+        self._NL= NL
+        
+    def __get_best(self):
+        lang,score= 0,self._max[0]
+        for l in range(1,self._NL):
+            val= self._max[l]
+            if val>score: lang,score= l,val
+        return lang
+
+    # Torna el millor (seg_id,lang_id) del segment actual
+    def __call__(self,lang_id,score,nframes):
+        if nframes==0: return self._seg_id,self.__get_best()
+        score/= nframes
+        self._max[lang_id]= score
+        return self._seg_id,self.__get_best()
+
+    @property
+    def next(self):
+        if self._next==None:
+            self._next= SegmentLangId(self._NL,self._seg_id+1)
+        return self._next
+    
+# end SegmentLangId
+
+class LangId:
+    def __init__(self,num_langs):
+        lid= SegmentLangId(num_langs)
+        self._current_lang= []
+        for i in range(num_langs):
+            self._current_lang.append(lid)
+
+    # Torna el millor (seg_id,lang_id) del segment actual
+    def __call__(self,lang_id,score,nframes,is_final):
+        current= self._current_lang[lang_id]
+        ret= current(lang_id,score,nframes)
+        if is_final:
+            self._current_lang[lang_id]= current.next
+        return ret
+
+# end HypMng
+
+class ExceptionContainer:
+
+    def __init__(self):
+        self._e= None
+
+    def __bool__(self):
+        return self._e is not None
+
+    def set(self,e):
+        self._e= e
+
+    @property
+    def exception(self):
+        return self._e
+
+# end ExceptionContainer
+
+class Counter:
+
+    def __init__(self,val):
+        self.val= val
+
+    def dec(self):
+        self.val-= 1
+    
+# end Counter
+
+# Pot generar excepcions de tot tipus. Incloent DecoderException.
+def decode(mng,request_iterator,ii):
+    # Launch threads
+    demux= Demultiplexor(request_iterator)
+    output_buffer= []
+    error= ExceptionContainer()
+    cv= threading.Condition()
+    lid= LangId(demux.num_streams)
+    stop_counter= Counter(demux.num_streams)
+    def transcribe(lang_id):
+        def my_ii(txt):
+            ii('[sys_id:%d] %s'%(demux.pos2sysid[lang_id],txt))
+        stop= False
+        for o in mng.decode(demux(lang_id),my_ii):
+            with cv:
+                seg_id,best_lid= lid(lang_id,o.score,o.nframes,o.eos)
+                best_sys_id= demux.pos2sysid[best_lid]
+                curr_sys_id= demux.pos2sysid[lang_id]
+                res= asr_pb2.DecodeMultiResponse(res=o,segment_id=seg_id,
+                                                 best_system=best_sys_id,
+                                                 current_system=curr_sys_id)
+                output_buffer.append(res)
+                cv.notify_all()
+        with cv:
+            stop_counter.dec()
+            cv.notify_all()
+    def transcribe_err(lang_id):
+        try:
+            transcribe(lang_id)
+        except BaseException as e:
+            with cv:
+                if not error:
+                    error.set(e)
+                    cv.notify_all()
+    pool= cf.ThreadPoolExecutor(max_workers=demux.num_streams)
+    for i in range(demux.num_streams):
+        pool.submit(transcribe_err,i)
+    # Processa l'exida
+    stop= False
+    while True:
+        with cv:
+            cv.wait_for(lambda: len(output_buffer)>0 or stop_counter.val==0 or error)
+            if error: stop= True
+            elif len(output_buffer)>0:
+                ret= output_buffer.pop(0)
+            else:
+                assert stop_counter.val==0
+                stop= True
+        if stop: break
+        yield ret
+    # Acaba
+    pool.shutdown()
+    if error:
+        raise error.exception
+# end decode

+ 1 - 0
code/test.txt

@@ -0,0 +1 @@
+Hello world, how are you today?

+ 71 - 0
code/test_asr_server.py

@@ -0,0 +1,71 @@
+import logging as Log
+import grpc,argparse,time,sys,threading
+
+import mllp_grpc.asr_pb2_grpc as asr_pb2_grpc
+import mllp_grpc.asr_pb2 as asr_pb2
+import mllp_grpc.asr_common_pb2 as asr_common_pb2
+import google.protobuf.empty_pb2 as empty_pb2
+
+from pydub import AudioSegment
+from math import ceil
+
+def get_data(args,sys_id):
+    Log.info("Loading '%s'"%args.wav)
+    sound= AudioSegment.from_wav(args.wav)
+    
+    if (sound.frame_rate != 16000):
+        sound= sound.set_frame_rate(16000)
+        
+    nbuffers= ceil(len(sound)/args.bsize)
+    Log.info(('c: %d r: %d rate: %d len: %d nbuf:'+
+              ' %d bsize: %d')%(sound.channels,
+                                sound.sample_width,
+                                sound.frame_rate,
+                                len(sound),
+                                nbuffers,
+                                args.bsize))
+    sended= 0
+    Log.info('Sending data')
+    yield asr_pb2.DecodeRequest(system_id=sys_id)
+    for n in range(0,nbuffers):
+        buf= sound[n*args.bsize:(n+1)*args.bsize]
+        yield asr_pb2.DecodeRequest(data=buf.raw_data)
+        sended+= args.bsize
+        if sended>=1000:
+            sended-= 1000
+            time.sleep(1)
+    Log.info('Transmission completed')
+        
+def parse_cmdline():
+    p= argparse.ArgumentParser(description='ASR GRPC client')
+    p.add_argument('host',type=str,help='server host')
+    p.add_argument('port',type=int,help='server port')
+    p.add_argument('wav',type=str,help='input 1 channel wav at 16KHz')
+    p.add_argument('--bsize',type=int,default=250,
+                   help='buffer size')
+    args= p.parse_args()
+    return args
+
+def run(args):
+    addr= '%s:%d'%(args.host,args.port)
+    with grpc.insecure_channel(addr) as channel:
+        stub= asr_pb2_grpc.ASRStub(channel)
+        for info in stub.GetSystemsInfo(empty_pb2.Empty()):
+            print(info)
+        print_host= True
+        for o in stub.Decode(get_data(args,1)):
+            if print_host:
+                print_host= False
+                print(o.host_info)
+            if o.status.code!=asr_common_pb2.DecodeResponse.Status.Code.OK:
+                print(o)
+            elif o.hyp_novar!='':
+                sys.stdout.write(o.hyp_novar+' ')
+            sys.stdout.flush()
+        sys.stdout.write('\n')
+        sys.stdout.flush()
+        
+if __name__=='__main__':
+    Log.basicConfig(level=Log.INFO)
+    args= parse_cmdline()
+    run(args)

+ 91 - 0
code/test_asr_system.py

@@ -0,0 +1,91 @@
+import logging as Log
+import grpc,argparse,time,sys,threading
+
+import mllp_grpc.asr_system_pb2_grpc as asr_system_pb2_grpc
+import mllp_grpc.asr_system_pb2 as asr_system_pb2
+import mllp_grpc.asr_common_pb2 as asr_common_pb2
+import google.protobuf.empty_pb2 as empty_pb2
+
+from pydub import AudioSegment
+from math import ceil
+
+class LazzyIterator:
+
+    def __init__(self,i):
+        self._i= i
+        self._cv= threading.Condition()
+        self._ready= False
+
+    def __call__(self):
+        with self._cv:
+            if not self._ready:
+                self._cv.wait()
+        for v in self._i:
+            yield v
+                
+    def ready(self):
+        with self._cv:
+            self._ready= True
+            self._cv.notify()
+        
+# end LazzyIterator
+
+def get_data(args):
+    Log.info("Loading '%s'"%args.wav)
+    sound= AudioSegment.from_wav(args.wav)
+    
+    if (sound.frame_rate != 16000):
+        sound= sound.set_frame_rate(16000)
+        
+    nbuffers= ceil(len(sound)/args.bsize)
+    Log.info(('c: %d r: %d rate: %d len: %d nbuf:'+
+              ' %d bsize: %d')%(sound.channels,
+                                sound.sample_width,
+                                sound.frame_rate,
+                                len(sound),
+                                nbuffers,
+                                args.bsize))
+    sended= 0
+    Log.info('Sending data')
+    for n in range(0,nbuffers):
+        buf= sound[n*args.bsize:(n+1)*args.bsize]
+        yield asr_common_pb2.DataPackage(data=buf.raw_data)
+        sended+= args.bsize
+        if sended>=1000:
+            sended-= 1000
+            time.sleep(1)
+    Log.info('Transmission completed')
+        
+def parse_cmdline():
+    p= argparse.ArgumentParser(description='ASR GRPC client')
+    p.add_argument('host',type=str,help='server host')
+    p.add_argument('port',type=int,help='server port')
+    p.add_argument('wav',type=str,help='input 1 channel wav at 16KHz')
+    p.add_argument('--bsize',type=int,default=250,
+                   help='buffer size')
+    args= p.parse_args()
+    return args
+
+def run(args):
+    addr= '%s:%d'%(args.host,args.port)
+    with grpc.insecure_channel(addr) as channel:
+        stub= asr_system_pb2_grpc.ASRSystemStub(channel)
+        info= stub.GetSystemInfo(empty_pb2.Empty())
+        print(info)
+        data= LazzyIterator(get_data(args))
+        for o in stub.Decode(data()):
+            if o.status.code==asr_common_pb2.DecodeResponse.Status.READY:
+                data.ready()
+                continue
+            if o.hyp_novar!='':
+                sys.stdout.write(o.hyp_novar+' ')
+                sys.stdout.flush()
+            if o.eos:
+                sys.stdout.write('\n')
+                sys.stdout.flush()
+                print('------------------------------------------------')
+        
+if __name__=='__main__':
+    Log.basicConfig(level=Log.INFO)
+    args= parse_cmdline()
+    run(args)

+ 61 - 0
code/test_smt_server.py

@@ -0,0 +1,61 @@
+import logging as Log
+import grpc,argparse,time,sys,threading
+
+import mllp_grpc.sim_mt_pb2_grpc as sim_mt_pb2_grpc
+import mllp_grpc.sim_mt_pb2 as sim_mt_pb2
+import mllp_grpc.sim_mt_common_pb2 as sim_mt_common_pb2
+import google.protobuf.empty_pb2 as empty_pb2
+
+
+def get_data(args,sys_id):
+    Log.info("Loading '%s'"%args.txt)
+    Log.info('Sending data')
+    yield sim_mt_pb2.TranslateRequest(system_id=sys_id)
+    with open(args.txt) as f:
+        for l in f:
+            l= l.strip().split()
+            if l==[]: continue
+            for w in l:
+                data= sim_mt_common_pb2.TextPackage(words_novar=[w],
+                                                    words_var=[],
+                                                    eos=False)
+                o= sim_mt_pb2.TranslateRequest(text=data)
+                yield o
+            data= sim_mt_common_pb2.TextPackage(words_novar=[],
+                                                    words_var=[],
+                                                    eos=True)
+            o= sim_mt_pb2.TranslateRequest(text=data)
+            yield o
+    Log.info('Transmission completed')
+        
+def parse_cmdline():
+    p= argparse.ArgumentParser(description='Simultaneous MT GRPC client')
+    p.add_argument('host',type=str,help='server host')
+    p.add_argument('port',type=int,help='server port')
+    p.add_argument('txt',type=str,help='input text file')
+    args= p.parse_args()
+    return args
+
+def run(args):
+    addr= '%s:%d'%(args.host,args.port)
+    with grpc.insecure_channel(addr) as channel:
+        stub= sim_mt_pb2_grpc.SimMTStub(channel)
+        for info in stub.GetSystemsInfo(empty_pb2.Empty()):
+            print(info)
+        print_host= True
+        for o in stub.Translate(get_data(args,0)):
+            if print_host:
+                print_host= False
+                print(o.host_info)
+            if o.status.code!=sim_mt_common_pb2.TranslateResponse.Status.Code.OK:
+                print(o)
+            else:
+                for w in o.words_novar:
+                    sys.stdout.write(w+' ')
+                if o.eos: sys.stdout.write('\n')
+                sys.stdout.flush()
+        
+if __name__=='__main__':
+    Log.basicConfig(level=Log.INFO)
+    args= parse_cmdline()
+    run(args)

+ 80 - 0
code/test_smt_system.py

@@ -0,0 +1,80 @@
+import logging as Log
+import grpc,argparse,time,sys,threading
+
+import mllp_grpc.sim_mt_system_pb2_grpc as sim_mt_system_pb2_grpc
+import mllp_grpc.sim_mt_system_pb2 as sim_mt_system_pb2
+import mllp_grpc.sim_mt_common_pb2 as sim_mt_common_pb2
+import google.protobuf.empty_pb2 as empty_pb2
+
+class LazzyIterator:
+
+    def __init__(self,i):
+        self._i= i
+        self._cv= threading.Condition()
+        self._ready= False
+
+    def __call__(self):
+        with self._cv:
+            if not self._ready:
+                self._cv.wait()
+        for v in self._i:
+            yield v
+                
+    def ready(self):
+        with self._cv:
+            self._ready= True
+            self._cv.notify()
+        
+# end LazzyIterator
+
+def get_data(args):
+    Log.info("Loading '%s'"%args.txt)
+    Log.info('Sending data')
+    with open(args.txt) as f:
+        for l in f:
+            l= l.strip().split()
+            if l==[]: continue
+            for w in l:
+                o= sim_mt_common_pb2.TextPackage(words_novar=[w],
+                                                 words_var=[],
+                                                 eos=False)
+                yield o
+            o= sim_mt_common_pb2.TextPackage(words_novar=[],
+                                             words_var=[],
+                                             eos=True)
+            yield o
+    Log.info('Transmission completed')
+        
+def parse_cmdline():
+    p= argparse.ArgumentParser(description='Simultaneous MT GRPC client')
+    p.add_argument('host',type=str,help='server host')
+    p.add_argument('port',type=int,help='server port')
+    p.add_argument('txt',type=str,help='txt file to translate')
+    args= p.parse_args()
+    return args
+
+def run(args):
+    addr= '%s:%d'%(args.host,args.port)
+    with grpc.insecure_channel(addr) as channel:
+        stub= sim_mt_system_pb2_grpc.SimMTSystemStub(channel)
+        info= stub.GetSystemInfo(empty_pb2.Empty())
+        stub.SetEnabled(sim_mt_system_pb2.SetEnabledRequest(value=True))
+        print(info)
+        data= LazzyIterator(get_data(args))
+        for o in stub.Translate(data()):
+            if o.status.code==sim_mt_common_pb2.TranslateResponse.Status.READY:
+                data.ready()
+                continue
+            if o.status.code!=sim_mt_common_pb2.TranslateResponse.Status.OK:
+                print(o)
+            else:
+                for w in o.words_novar:
+                    sys.stdout.write(w+' ')
+                if o.eos: sys.stdout.write('\n')
+                sys.stdout.flush()
+        print('------------------------------------------------')
+        
+if __name__=='__main__':
+    Log.basicConfig(level=Log.INFO)
+    args= parse_cmdline()
+    run(args)