# robot.ASR source code

# -*- coding: utf-8-*-
import json
from aip import AipSpeech
from .sdk import TencentSpeech, AliSpeech, XunfeiSpeech
from . import utils, config
from robot import logging
from abc import ABCMeta, abstractmethod

# Module-level logger shared by every engine below; obtained from the
# project's `robot.logging` module and named after this module.
logger = logging.getLogger(__name__)

class AbstractASR(metaclass=ABCMeta):
    """
    Generic parent class for all ASR engines.

    Subclasses override :meth:`get_config` to supply the keyword arguments
    for their constructor and must implement :meth:`transcribe`.
    """

    # NOTE: the metaclass is declared in the class header on purpose.  A
    # ``__metaclass__ = ABCMeta`` class attribute is ignored by Python 3, which
    # silently disables the ``@abstractmethod`` check below.

    @classmethod
    def get_config(cls):
        """Return the constructor keyword arguments for this engine.

        The base implementation returns an empty dict.
        """
        return {}

    @classmethod
    def get_instance(cls):
        """Create an engine by passing ``get_config()`` as keyword arguments.

        Returns:
            A new instance of ``cls``.
        """
        profile = cls.get_config()
        instance = cls(**profile)
        return instance

    @abstractmethod
    def transcribe(self, fp):
        """Transcribe the audio identified by ``fp``.

        Args:
            fp: the audio to recognize (presumably a path to a wav file;
                confirm against the callers).

        Concrete engines must override this method.
        """
        pass
class BaiduASR(AbstractASR):
    """
    Baidu speech recognition API.

    dev_pid:
        - 1936: Mandarin, far-field
        - 1536: Mandarin (with simple English recognition)
        - 1537: Mandarin (Chinese only)
        - 1737: English
        - 1637: Cantonese
        - 1837: Sichuanese

    To use this module, first register a developer account at
    yuyin.baidu.com, create a new application, then take the API Key and
    Secret Key shown under "view key" in the application management page
    and put them in the configuration under the ``baidu_yuyin`` key:

        ...
        baidu_yuyin:
            appid: '<your app id>'
            api_key: '<your api key>'
            secret_key: '<your secret key>'
        ...
    """

    SLUG = "baidu-asr"

    def __init__(self, appid, api_key, secret_key, dev_pid=1936, **args):
        """Create the Baidu client.

        Args:
            appid, api_key, secret_key: credentials passed to ``AipSpeech``.
            dev_pid: recognition model id (see the class docstring).
            **args: extra configuration keys; accepted and ignored.
        """
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(BaiduASR, self).__init__()
        self.client = AipSpeech(appid, api_key, secret_key)
        self.dev_pid = dev_pid

    @classmethod
    def get_config(cls):
        """Return the ``baidu_yuyin`` section of the config (``{}`` if absent)."""
        return config.get('baidu_yuyin', {})

    def transcribe(self, fp):
        """Recognize a local audio file.

        Args:
            fp: the wav input handed to ``utils.get_pcm_from_wav``.

        Returns:
            The recognized text, or ``''`` when recognition fails.
        """
        pcm = utils.get_pcm_from_wav(fp)
        res = self.client.asr(pcm, 'pcm', 16000, {
            'dev_pid': self.dev_pid,
        })
        # Read the status defensively so a response without 'err_no' is
        # reported as a failure instead of raising KeyError.
        # NOTE(review): the SDK appears to answer with
        # 'error_code'/'error_msg' on transport failures; confirm.
        if res.get('err_no') == 0:
            logger.info('{} 语音识别到了：{}'.format(self.SLUG, res['result']))
            return ''.join(res['result'])
        error = res.get('err_msg', res.get('error_msg', res))
        logger.info('{} 语音识别出错了: {}'.format(self.SLUG, error))
        return ''
class TencentASR(AbstractASR):
    """
    Tencent speech recognition API.
    """

    SLUG = "tencent-asr"

    def __init__(self, appid, secretid, secret_key, region='ap-guangzhou', **args):
        """Create the Tencent engine.

        Args:
            appid: accepted for configuration compatibility; not used here.
            secretid, secret_key: credentials passed to ``tencentSpeech``.
            region: region string forwarded to every ASR request.
            **args: extra configuration keys; accepted and ignored.
        """
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(TencentASR, self).__init__()
        self.engine = TencentSpeech.tencentSpeech(secret_key, secretid)
        self.region = region

    @classmethod
    def get_config(cls):
        """Return the ``tencent_yuyin`` section of the config (``{}`` if absent)."""
        return config.get('tencent_yuyin', {})

    def transcribe(self, fp):
        """Convert ``fp`` to mp3, send it to the engine and return the text.

        Returns:
            The ``Response.Result`` field of the reply, or ``''`` when the
            reply does not contain one.
        """
        mp3_path = utils.convert_wav_to_mp3(fp)
        try:
            r = self.engine.ASR(mp3_path, 'mp3', '1', self.region)
        finally:
            # Always remove the temporary mp3, even if the request raised.
            utils.check_and_delete(mp3_path)
        res = json.loads(r)
        if 'Response' in res and 'Result' in res['Response']:
            logger.info('{} 语音识别到了：{}'.format(self.SLUG, res['Response']['Result']))
            return res['Response']['Result']
        # No exception is active here, so exc_info is not passed (it would
        # only append a meaningless "NoneType: None" to the log record).
        logger.critical('{} 语音识别出错了'.format(self.SLUG))
        return ''
class XunfeiASR(AbstractASR):
    """
    iFLYTEK (Xunfei) speech recognition API.

    Public IP lookup: https://ip.51240.com/
    """

    SLUG = "xunfei-asr"

    def __init__(self, appid, asr_api_key, asr_api_secret, tts_api_key, voice='xiaoyan', **args):
        """Store the credentials used for recognition.

        Args:
            appid, asr_api_key, asr_api_secret: credentials forwarded to
                ``XunfeiSpeech.transcribe``.
            tts_api_key, voice: accepted because the ``xunfei_yuyin``
                configuration section is passed in whole; not used here.
            **args: extra configuration keys; accepted and ignored, like the
                other engines, so additional keys in the section do not make
                ``get_instance()`` fail.
        """
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(XunfeiASR, self).__init__()
        self.appid = appid
        self.api_key = asr_api_key
        self.api_secret = asr_api_secret

    @classmethod
    def get_config(cls):
        """Return the ``xunfei_yuyin`` section of the config (``{}`` if absent)."""
        return config.get('xunfei_yuyin', {})

    def transcribe(self, fp):
        """Delegate recognition of ``fp`` to ``XunfeiSpeech.transcribe``."""
        return XunfeiSpeech.transcribe(fp, self.appid, self.api_key, self.api_secret)
class AliASR(AbstractASR):
    """
    Alibaba speech recognition API.
    """

    SLUG = "ali-asr"

    def __init__(self, appKey, token, **args):
        """Store the credentials used for recognition.

        Args:
            appKey, token: credentials forwarded to ``AliSpeech.asr``.
            **args: extra configuration keys; accepted and ignored.
        """
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(AliASR, self).__init__()
        self.appKey, self.token = appKey, token

    @classmethod
    def get_config(cls):
        """Return the ``ali_yuyin`` section of the config (``{}`` if absent)."""
        return config.get('ali_yuyin', {})

    def transcribe(self, fp):
        """Recognize ``fp`` through ``AliSpeech.asr``.

        Returns:
            The value returned by ``AliSpeech.asr``, or ``''`` when it
            returns ``None``.
        """
        result = AliSpeech.asr(self.appKey, self.token, fp)
        if result is not None:
            logger.info('{} 语音识别到了：{}'.format(self.SLUG, result))
            return result
        # No exception is active here, so exc_info is not passed (it would
        # only append a meaningless "NoneType: None" to the log record).
        logger.critical('{} 语音识别出错了'.format(self.SLUG))
        return ''
def get_engine_by_slug(slug=None):
    """
    Instantiate the ASR engine whose ``SLUG`` equals ``slug``.

    Args:
        slug: non-empty string naming the engine.

    Returns:
        An instance of the matching engine, created with its
        ``get_instance()`` classmethod.  If several engines share the slug a
        warning is logged and the first match is used.

    Raises:
        TypeError: if ``slug`` is empty or not a string.
        ValueError: if no engine with that slug is found.
    """
    if not slug or not isinstance(slug, str):
        # Interpolate the slug into the message; passing it as a second
        # exception argument would leave the '%s' placeholder unformatted.
        raise TypeError("无效的 ASR slug '%s'" % (slug,))

    selected_engines = [
        engine for engine in get_engines()
        if getattr(engine, "SLUG", None) == slug
    ]

    if not selected_engines:
        raise ValueError("错误：找不到名为 {} 的 ASR 引擎".format(slug))

    if len(selected_engines) > 1:
        # Format the message before logging; calling .format() on the return
        # value of logger.warning() (None) raises AttributeError.
        logger.warning("注意: 有多个 ASR 名称与指定的引擎名 {} 匹配".format(slug))
    engine = selected_engines[0]
    logger.info("使用 {} ASR 引擎".format(engine.SLUG))
    return engine.get_instance()
def get_engines():
    """Return every direct or indirect ``AbstractASR`` subclass with a truthy ``SLUG``."""

    def collect_descendants(klass):
        # Recursively gather the whole subclass tree below `klass` into a set.
        found = set()
        for child in klass.__subclasses__():
            found.add(child)
            found.update(collect_descendants(child))
        return found

    engines = []
    for candidate in collect_descendants(AbstractASR):
        # Skip classes without a SLUG attribute or with an empty one.
        if getattr(candidate, 'SLUG', None):
            engines.append(candidate)
    return engines