跳转至

Retrieve

labridge.tools.instrument.retrieve

labridge.tools.instrument.retrieve.InstrumentRetrieverTool

Bases: RetrieverBaseTool

Source code in labridge\tools\instrument\retrieve.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class InstrumentRetrieverTool(RetrieverBaseTool):
	r"""
	Tool that retrieves instrument docs through an `InstrumentRetriever` and
	renders the retrieved nodes as a text block grouped by instrument.

	Args:
		llm (LLM): Forwarded to the `InstrumentRetriever` constructor.
		embed_model (BaseEmbedding): Forwarded to the `InstrumentRetriever` constructor.
		metadata_mode (MetadataMode): Passed to `get_content` when the retrieved
			node contents are rendered. Defaults to `MetadataMode.NONE`.
	"""
	def __init__(
		self,
		llm: LLM = None,
		embed_model: BaseEmbedding = None,
		metadata_mode: MetadataMode = MetadataMode.NONE,
	):
		instrument_retriever = InstrumentRetriever(
			llm=llm,
			embed_model=embed_model,
		)
		# These attributes are assigned before the base initializer runs, exactly as in
		# the original ordering.
		self.metadata_mode = metadata_mode
		self.super_user_manager = InstrumentSuperUserManager()
		super().__init__(
			retriever=instrument_retriever,
			name=InstrumentRetrieverTool.__name__,
			retrieve_fn=InstrumentRetriever.retrieve
		)

	def log(self, log_dict: dict) -> ToolLog:
		r"""
		Build the `ToolLog` of one retrieval.

		Args:
			log_dict (dict): Must contain `TOOL_LOG_REF_INFO_KEY`, mapped to the list of
				`InstrumentInfo` objects of the retrieval.

		Returns:
			ToolLog: `log_to_user` is None; `log_to_system` holds the operation description
			and the dumped instrument infos.
		"""
		ref_infos: List[InstrumentInfo] = log_dict[TOOL_LOG_REF_INFO_KEY]
		instrument_infos = [info.dumps() for info in ref_infos]

		log_to_user = None
		log_to_system = {
			TOOL_OP_DESCRIPTION: f"Use the {self.metadata.name} to retrieve the instrument docs.",
			TOOL_REFERENCES: instrument_infos,
		}

		return ToolLog(
			tool_name=self.metadata.name,
			log_to_user=log_to_user,
			log_to_system=log_to_system,
		)

	def _retrieve(self, retrieve_kwargs: dict) -> List[NodeWithScore]:
		r""" Call the wrapped retriever's `retrieve` with the given keyword arguments. """
		nodes = self._retriever.retrieve(**retrieve_kwargs)
		return nodes

	async def _aretrieve(self, retrieve_kwargs: dict) -> List[NodeWithScore]:
		r""" Await the wrapped retriever's `aretrieve` with the given keyword arguments. """
		nodes = await self._retriever.aretrieve(**retrieve_kwargs)
		return nodes

	def get_ref_info(self, nodes: List[NodeWithScore]) -> List[RefInfoBase]:
		r"""
		Get the reference infos from the retrieved nodes.

		One `InstrumentInfo` is produced per distinct instrument id, in order of first
		appearance. Nodes whose id equals `INSTRUMENT_ROOT_NODE_NAME` are skipped.

		Args:
			nodes (List[NodeWithScore]): The retrieved nodes.

		Returns:
			List[RefInfoBase]: The instrument infos, each carrying the instrument id and
			the super users returned by the super-user manager.
		"""
		instrument_infos = []
		instrument_set = set()
		for node in nodes:
			# Fall back to the node id when the metadata carries no instrument name.
			instrument_id = node.metadata.get(INSTRUMENT_NAME_KEY, node.node_id)
			# TODO: Add node type and filter.
			if instrument_id == INSTRUMENT_ROOT_NODE_NAME or instrument_id in instrument_set:
				continue
			instrument_set.add(instrument_id)
			super_users = self.super_user_manager.get_super_users(
				instrument_id=instrument_id,
			)
			info = InstrumentInfo(
				instrument_id=instrument_id,
				super_users=super_users,
			)
			instrument_infos.append(info)
		return instrument_infos

	def _nodes_to_tool_output(self, nodes: List[NodeWithScore]) -> Tuple[str, dict]:
		r"""
		Output the retrieved contents in a specific format.

		The nodes are grouped by instrument id (root nodes are skipped) and each group is
		rendered as an `Instrument Name:` line followed by its numbered contents.
		If no instrument content remains, either because `nodes` is empty or because every
		node was a root node, the header is followed by a "no relevant contents" message.

		Args:
			nodes (List[NodeWithScore]): The retrieved nodes.

		Returns:
			Tuple[str, dict]: The formatted output string and the log dict that maps
			`TOOL_LOG_REF_INFO_KEY` to the reference infos of `nodes`.
		"""
		ref_infos = self.get_ref_info(nodes=nodes)
		log_dict = {
			TOOL_LOG_REF_INFO_KEY: ref_infos,
		}

		# Group the nodes by instrument, keeping the order of first appearance.
		instrument_docs = dict()
		for node in nodes:
			instrument_id = node.metadata.get(INSTRUMENT_NAME_KEY, node.node_id)
			# TODO: Add node type and filter.
			if instrument_id == INSTRUMENT_ROOT_NODE_NAME:
				continue
			instrument_docs.setdefault(instrument_id, []).append(node)

		header = "Have retrieved the docs of several relevant instruments:\n\n"

		# Decide on the grouped result rather than on `nodes`, so that a retrieval
		# consisting only of root nodes does not yield a bare header.
		if not instrument_docs:
			return header + "No relevant instrument contents found.", log_dict

		sections = []
		for instrument_id, instrument_nodes in instrument_docs.items():
			parts = [f"Instrument Name: {instrument_id}\n"]
			for idx, node in enumerate(instrument_nodes):
				parts.append(
					f"Retrieved content {idx + 1}:\n"
					f"{node.node.get_content(metadata_mode=self.metadata_mode)}\n"
				)
			# A blank line separates consecutive instruments.
			sections.append("".join(parts) + "\n")
		return header + "".join(sections), log_dict

labridge.tools.instrument.retrieve.InstrumentRetrieverTool.get_ref_info(nodes)

Get the reference info of each distinct instrument from the retrieved nodes.

Source code in labridge\tools\instrument\retrieve.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def get_ref_info(self, nodes: List[NodeWithScore]) -> List[RefInfoBase]:
	r"""
	Collect one reference info per distinct instrument among the retrieved nodes.

	Args:
		nodes (List[NodeWithScore]): The retrieved nodes.

	Returns:
		List[RefInfoBase]: One `InstrumentInfo` per instrument id, in order of first
		appearance. Ids equal to `INSTRUMENT_ROOT_NODE_NAME` are left out.
	"""
	seen_ids = set()
	ref_infos = []
	for retrieved in nodes:
		# The node id stands in when the metadata has no instrument name.
		name = retrieved.metadata.get(INSTRUMENT_NAME_KEY, retrieved.node_id)
		# TODO: Add node type and filter.
		if name != INSTRUMENT_ROOT_NODE_NAME and name not in seen_ids:
			seen_ids.add(name)
			managers = self.super_user_manager.get_super_users(instrument_id=name)
			ref_infos.append(
				InstrumentInfo(instrument_id=name, super_users=managers)
			)
	return ref_infos