跳转至

Retriever

labridge.tools.paper.shared_papers.retriever

labridge.tools.paper.shared_papers.retriever.SharedPaperRetrieverTool

Bases: RetrieverBaseTool

This tool retrieves from the laboratory's shared paper storage.

Multi-level, hybrid retrieval is used for accurate results. For details of the retrieval process, refer to the docstring of PaperRetriever.

PARAMETER DESCRIPTION
llm

The used LLM.

TYPE: LLM DEFAULT: None

embed_model

The used embedding model.

TYPE: BaseEmbedding DEFAULT: None

vector_similarity_top_k

The top-k of content-based retrieving. Defaults to PAPER_VECTOR_TOP_K.

TYPE: int DEFAULT: PAPER_VECTOR_TOP_K

summary_similarity_top_k

The top-k of summary-based retrieving. Defaults to PAPER_SUMMARY_TOP_K.

TYPE: int DEFAULT: PAPER_SUMMARY_TOP_K

docs_top_k

The top-k docs will be selected. Defaults to PAPER_TOP_K.

TYPE: int DEFAULT: PAPER_TOP_K

re_retrieve_top_k

The top-k of retrieving among the selected docs_top_k docs. Defaults to PAPER_RETRIEVE_TOP_K.

TYPE: int DEFAULT: PAPER_RETRIEVE_TOP_K

final_use_context

Whether to use the context nodes of the retrieved nodes as parts of results. Defaults to True.

TYPE: bool DEFAULT: True

final_use_summary

Whether to use the summary nodes of the retrieved nodes' relevant docs as parts of results. Defaults to True.

TYPE: bool DEFAULT: True

Source code in labridge\tools\paper\shared_papers\retriever.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
class SharedPaperRetrieverTool(RetrieverBaseTool):
	r"""
	This tool retrieves from the laboratory's shared paper storage.

	Multi-level, hybrid retrieving is used for accurate results.
	For details of retrieving, refer to the docstring of `PaperRetriever`.

	Args:
		llm (LLM): The used LLM. Falls back to `Settings.llm` when omitted.
		embed_model (BaseEmbedding): The used embedding model. Falls back to `Settings.embed_model` when omitted.
		vector_similarity_top_k (int): The top-k of content-based retrieving. Defaults to `PAPER_VECTOR_TOP_K`.
		summary_similarity_top_k (int): The top-k of summary-based retrieving. Defaults to `PAPER_SUMMARY_TOP_K`.
		docs_top_k (int): The top-k docs will be selected. Defaults to `PAPER_TOP_K`.
		re_retrieve_top_k (int): The top-k of retrieving among the selected `docs_top_k` docs.
			Defaults to `PAPER_RETRIEVE_TOP_K`.
		final_use_context (bool): Whether to use the context nodes of the retrieved nodes as parts of results.
			Defaults to True.
		final_use_summary (bool): Whether to use the summary nodes of the retrieved nodes' relevant docs as parts of results.
			Defaults to True.

	Note:
		The retriever is built by `SharedPaperRetriever.from_storage` from `llm` and `embed_model` only.
		The top-k and `final_use_*` arguments are accepted but are not forwarded to the retriever.
	"""
	def __init__(
		self,
		llm: LLM = None,
		embed_model: BaseEmbedding = None,
		vector_similarity_top_k: int = PAPER_VECTOR_TOP_K,
		summary_similarity_top_k: int = PAPER_SUMMARY_TOP_K,
		docs_top_k: int = PAPER_TOP_K,
		re_retrieve_top_k: int = PAPER_RETRIEVE_TOP_K,
		final_use_context: bool = True,
		final_use_summary: bool = True,
	):
		self._llm = llm or Settings.llm
		self._embed_model = embed_model or Settings.embed_model

		# NOTE: only the LLM and the embedding model are handed to the retriever.
		# The remaining constructor arguments are kept in the signature so existing
		# callers keep working, but they have no effect here.
		paper_retriever = SharedPaperRetriever.from_storage(
			llm=self._llm,
			embed_model=self._embed_model,
		)
		super().__init__(
			name=SharedPaperRetrieverTool.__name__,
			retriever=paper_retriever,
			retrieve_fn=paper_retriever.retrieve,
		)

		# Walk five directory levels up from this file; the relative paper paths
		# stored in the node metadata are joined onto this directory.
		root = Path(__file__)
		for _ in range(5):
			root = root.parent
		self.root = root

	def log(self, log_dict: dict) -> ToolLog:
		r"""
		Return the ToolLog with log string in a specific format.

		Args:
			log_dict (dict): Must contain the key "item_to_be_retrieved". May contain
				`TOOL_LOG_REF_INFO_KEY` mapped to a list of `PaperInfo`.

		Returns:
			ToolLog: The log whose system part holds the operation description and the dumped references.

		Raises:
			KeyError: If "item_to_be_retrieved" is missing from `log_dict`.
		"""
		item_to_be_retrieved = log_dict["item_to_be_retrieved"]

		# A missing (or None) reference entry is treated as "no references"
		# instead of failing while iterating over None.
		ref_infos: List[PaperInfo] = log_dict.get(TOOL_LOG_REF_INFO_KEY) or []

		op_log = (
			f"Retrieve in the shared papers.\n"
			f"retrieve string: {item_to_be_retrieved}\n"
		)
		log_to_system = {
			TOOL_OP_DESCRIPTION: op_log,
			TOOL_REFERENCES: [ref_info.dumps() for ref_info in ref_infos]
		}
		return ToolLog(
			tool_name=self.metadata.name,
			log_to_user=None,
			log_to_system=log_to_system,
		)

	def _retrieve(self, retrieve_kwargs: dict) -> List[NodeWithScore]:
		r""" Use the retriever to retrieve relevant nodes. """
		return self._retriever.retrieve(**retrieve_kwargs)

	async def _aretrieve(self, retrieve_kwargs: dict) -> List[NodeWithScore]:
		r""" Asynchronously use the retriever to retrieve relevant nodes. """
		return await self._retriever.aretrieve(**retrieve_kwargs)

	def get_ref_info(self, nodes: List[NodeWithScore]) -> List[PaperInfo]:
		r"""
		Get the reference paper infos, one per distinct source document.

		Nodes without a `ref_doc_id` are skipped. The result keeps the order in which
		each document first appears in `nodes`.

		Args:
			nodes (List[NodeWithScore]): The retrieved nodes.

		Returns:
			List[PaperInfo]: The reference paper infos in answering.

		Raises:
			ValueError: If the metadata of a node has no `PAPER_REL_FILE_PATH` entry.
		"""
		seen_doc_ids = set()
		ref_infos = []
		for node_score in nodes:
			node = node_score.node
			ref_doc_id = node.ref_doc_id
			if not ref_doc_id or ref_doc_id in seen_doc_ids:
				continue
			seen_doc_ids.add(ref_doc_id)

			metadata = node.metadata
			rel_path = metadata.get(PAPER_REL_FILE_PATH)
			if rel_path is None:
				raise ValueError("Invalid database.")
			ref_infos.append(
				PaperInfo(
					# Fall back to the document id when no title is recorded.
					title=metadata.get(PAPER_TITLE) or ref_doc_id,
					possessor=metadata.get(PAPER_POSSESSOR),
					file_path=str(self.root / rel_path),
					doi=metadata.get(PAPER_DOI),
				)
			)
		return ref_infos

	def _nodes_to_tool_output(self, nodes: List[NodeWithScore]) -> Tuple[str, dict]:
		r"""
		Output the retrieved contents in a specific format, and the output log.

		Args:
			nodes (List[NodeWithScore]): The retrieved nodes.

		Returns:
			Tuple[str, dict]: The formatted contents grouped by `ref_doc_id`, and a log dict
			holding the reference infos under `TOOL_LOG_REF_INFO_KEY`.
		"""
		log_dict = {
			TOOL_LOG_REF_INFO_KEY: self.get_ref_info(nodes=nodes),
		}

		# Group the node contents by source document, keeping first-seen order.
		paper_contents = {}
		for node in nodes:
			paper_contents.setdefault(node.node.ref_doc_id, []).append(
				node.get_content(metadata_mode=MetadataMode.LLM)
			)

		if not paper_contents:
			return "Have retrieved nothing.\n", log_dict

		contents = []
		for doc_name, doc_contents in paper_contents.items():
			each_str = f"Following contents are from the paper: {doc_name}:\n"
			each_str += "\n".join(doc_contents)
			contents.append(each_str.strip())
		content_str = "Have retrieved the following contents: \n" + "\n\n".join(contents)
		return content_str, log_dict

labridge.tools.paper.shared_papers.retriever.SharedPaperRetrieverTool.get_ref_info(nodes)

Get the reference paper infos

RETURNS DESCRIPTION
List[PaperInfo]

List[PaperInfo]: Reference information for the papers used in the answer.

Source code in labridge\tools\paper\shared_papers\retriever.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def get_ref_info(self, nodes: List[NodeWithScore]) -> List[PaperInfo]:
	r"""
	Get the reference paper infos, one per distinct source document.

	Nodes without a `ref_doc_id` are skipped. The result keeps the order in which
	each document first appears in `nodes`.

	Args:
		nodes (List[NodeWithScore]): The retrieved nodes.

	Returns:
		List[PaperInfo]: The reference paper infos in answering.

	Raises:
		ValueError: If the metadata of a node has no `PAPER_REL_FILE_PATH` entry.
	"""
	seen_doc_ids = set()
	ref_infos = []
	for node_score in nodes:
		node = node_score.node
		ref_doc_id = node.ref_doc_id
		if not ref_doc_id or ref_doc_id in seen_doc_ids:
			continue
		seen_doc_ids.add(ref_doc_id)

		metadata = node.metadata
		rel_path = metadata.get(PAPER_REL_FILE_PATH)
		if rel_path is None:
			raise ValueError("Invalid database.")
		ref_infos.append(
			PaperInfo(
				# Fall back to the document id when no title is recorded.
				title=metadata.get(PAPER_TITLE) or ref_doc_id,
				possessor=metadata.get(PAPER_POSSESSOR),
				file_path=str(self.root / rel_path),
				doi=metadata.get(PAPER_DOI),
			)
		)
	return ref_infos

labridge.tools.paper.shared_papers.retriever.SharedPaperRetrieverTool.log(log_dict)

Return the ToolLog with log string in a specific format.

Source code in labridge\tools\paper\shared_papers\retriever.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def log(self, log_dict: dict) -> ToolLog:
	r"""
	Return the ToolLog with log string in a specific format.

	Args:
		log_dict (dict): Must contain the key "item_to_be_retrieved". May contain
			`TOOL_LOG_REF_INFO_KEY` mapped to a list of `PaperInfo`.

	Returns:
		ToolLog: The log whose system part holds the operation description and the dumped references.

	Raises:
		KeyError: If "item_to_be_retrieved" is missing from `log_dict`.
	"""
	item_to_be_retrieved = log_dict["item_to_be_retrieved"]

	# A missing (or None) reference entry is treated as "no references"
	# instead of failing while iterating over None.
	ref_infos: List[PaperInfo] = log_dict.get(TOOL_LOG_REF_INFO_KEY) or []

	op_log = (
		f"Retrieve in the shared papers.\n"
		f"retrieve string: {item_to_be_retrieved}\n"
	)
	log_to_system = {
		TOOL_OP_DESCRIPTION: op_log,
		TOOL_REFERENCES: [ref_info.dumps() for ref_info in ref_infos]
	}
	return ToolLog(
		tool_name=self.metadata.name,
		log_to_user=None,
		log_to_system=log_to_system,
	)