Source code for harvesttext.ent_retrieve

import numpy as np
from collections import defaultdict

[docs]class EntRetrieveMixin: """ 实体检索模块: - 基于倒排索引快速检索包括某个实体的文档,以及统计出现某实体的文档数目 """
[docs] def build_index(self, docs, with_entity=True, with_type=True): if len(self.entity_type_dict) == 0: raise Exception("请先使用add_entities等函数添加希望关注的实体,再进行索引和检索") inv_index = defaultdict(set) for i, sent in enumerate(docs): entities_info = self.entity_linking(sent) for span, (entity, type0) in entities_info: if with_entity: inv_index[entity].add(i) if with_type: inv_index[type0].add(i) return inv_index
[docs] def get_entity_counts(self, docs, inv_index, used_type=[]): if len(inv_index) == 0: raise Exception("请先使用add_entities等函数添加希望关注的实体,再进行索引和检索") if len(used_type) > 0: entities = iter(x for x in self.entity_type_dict if self.entity_type_dict[x] in used_type) else: entities = self.entity_type_dict.keys() cnt = {enty: len(inv_index[enty]) for enty in entities if enty in inv_index} return cnt
[docs] def search_entity(self, query, docs, inv_index): if len(inv_index) == 0: raise Exception("请先使用add_entities等函数添加希望关注的实体,再进行索引和检索") words = query.split() if len(words) > 0: ids = inv_index[words[0]] for word in words[1:]: ids = ids & inv_index[word] np_docs = np.array(docs)[list(ids)] return np_docs.tolist() else: return []