跳转至

Base

labridge.func_modules.memory.base

labridge.func_modules.memory.base.LogBaseRetriever

Bases: object

This is the base class for log-type information retrievers, such as those for chat history and experiment logs.

The attributes memory and memory_vector_retriever should be specified in the subclass, and they will be updated in the method retrieve.

PARAMETER DESCRIPTION
embed_model

The used embedding model.

TYPE: BaseEmbedding

final_use_context

Whether to use the context nodes of the retrieved nodes as the final results.

TYPE: bool

relevant_top_k

The top-k relevant retrieved nodes will be used.

TYPE: int

Note

The docstring of the Method retrieve will be used as the tool description of the corresponding retriever tool.

Source code in labridge\func_modules\memory\base.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
class LogBaseRetriever(object):
	r"""
	Base class for log-type information retrievers, such as chat history and experiment logs.

	The attributes `memory` and `memory_vector_retriever` are initialised to `None` here; they should be
	specified in the subclass, and they will be updated in the method `retrieve`.

	Args:
		embed_model (BaseEmbedding): The used embedding model. A falsy value falls back to `Settings.embed_model`.
		final_use_context (bool): Whether to use the context nodes of the retrieved nodes as the final results.
		relevant_top_k (int): The top-k relevant retrieved nodes will be used.

	Note:
		The docstring of the Method `retrieve` will be used as the tool description of the corresponding
		retriever tool.

		This class derives from plain `object`, so the `@abstractmethod` markers below document the
		intended interface but are not enforced at instantiation time.
	"""
	def __init__(
		self,
		embed_model: BaseEmbedding,
		final_use_context: bool,
		relevant_top_k: int,
	):
		# Both are placeholders that the subclass is expected to fill in.
		self.memory = None
		self.memory_vector_retriever = None
		# Fall back to the globally configured embedding model when none is supplied.
		self.embed_model = embed_model or Settings.embed_model
		self.final_use_context = final_use_context
		self.relevant_top_k = relevant_top_k

	def _parse_date(self, start_date_str: str, end_date_str: str) -> List[str]:
		r"""
		Get the strings of the dates between the start date and the end date (including them).

		This is a thin wrapper that delegates to `parse_date_list`.

		Args:
			start_date_str (str): The string of the start date in a specific format, specified in `common.utils.time`.
			end_date_str (str): The string of the end date.

		Returns:
			List[str]: The date strings produced by `parse_date_list` for the given range.
		"""
		return parse_date_list(
			start_date_str=start_date_str,
			end_date_str=end_date_str,
		)

	@abstractmethod
	def get_memory_vector_retriever(self) -> VectorIndexRetriever:
		r""" Get the vector index retriever from the memory. To be implemented by the subclass. """

	@abstractmethod
	def get_memory_vector_index(self) -> VectorStoreIndex:
		r""" Get the vector index. To be implemented by the subclass; its docstore is used by `_add_context`. """

	def get_date_filter(self, date_list: List[str]) -> MetadataFilter:
		r"""
		Return the MetadataFilter that filters nodes with dates in the date_list.

		Args:
			date_list (List[str]): The candidate date strings.

		Returns:
			MetadataFilter: The date filter, keyed on `LOG_DATE_NAME` with the `ANY` operator.
		"""
		return MetadataFilter(
			key=LOG_DATE_NAME,
			value=date_list,
			operator=FilterOperator.ANY,
		)

	def _log_node_filter(self) -> MetadataFilter:
		r"""
		Return the filter that filters `LOG_NODE_TYPE` nodes.

		Returns:
			MetadataFilter: The node_type filter, requiring `MEMORY_NODE_TYPE_NAME` to equal `LOG_NODE_TYPE`.
		"""
		return MetadataFilter(
			key=MEMORY_NODE_TYPE_NAME,
			value=LOG_NODE_TYPE,
			operator=FilterOperator.EQ,
		)

	def sort_retrieved_nodes(
		self,
		memory_nodes: List[NodeWithScore],
		descending: bool = False,
	) -> List[NodeWithScore]:
		r"""
		Sort the retrieved nodes according to their datetime.

		The datetime of a node is built by `str_to_datetime` from the first entry of its
		`LOG_DATE_NAME` metadata and the first entry of its `LOG_TIME_NAME` metadata.
		The sort is stable: nodes with equal datetimes keep their relative input order.

		Args:
			memory_nodes (List[NodeWithScore]): The retrieved nodes.
			descending (bool): Sort in descending order. Defaults to False.

		Returns:
			List[NodeWithScore]: A new list holding the sorted nodes (empty if the input is empty).
		"""
		if not memory_nodes:
			return []

		def _node_datetime(scored_node: NodeWithScore):
			# NOTE(review): the metadata values are indexed with [0], so they are presumably
			# lists of strings -- confirm against the code that writes the log nodes.
			metadata = scored_node.node.metadata
			return str_to_datetime(
				date_str=metadata[LOG_DATE_NAME][0],
				time_str=metadata[LOG_TIME_NAME][0],
			)

		# `sorted` returns a real list, matching the declared return type
		# (unpacking via `zip(*...)` would hand back a tuple instead).
		return sorted(memory_nodes, key=_node_datetime, reverse=descending)

	def _add_context(self, content_nodes: List[NodeWithScore]) -> List[NodeWithScore]:
		r"""
		Add the 1-hop context nodes of each content node and keep the QA time order.
		Only the context nodes whose date is the same as the retrieved node will be added.

		Context nodes are wrapped in a `NodeWithScore` without a score, and a node is never added twice.

		Args:
			content_nodes (List[NodeWithScore]): The retrieved nodes.

		Returns:
			List[NodeWithScore]: The final nodes including the context nodes, sorted by datetime.
		"""
		# A set gives O(1) duplicate checks while the result is being assembled.
		seen_ids = {scored_node.node.node_id for scored_node in content_nodes}
		final_nodes = []
		docstore = self.get_memory_vector_index().docstore

		def _append_neighbour(neighbour_info, node_date) -> None:
			# Add the neighbour unless it is absent, already included, or from another date.
			if neighbour_info is None:
				return
			neighbour_id = neighbour_info.node_id
			if neighbour_id in seen_ids:
				# Already part of the result; skip the docstore lookup entirely.
				return
			neighbour = docstore.get_node(neighbour_id)
			if neighbour.metadata[LOG_DATE_NAME] == node_date:
				seen_ids.add(neighbour_id)
				final_nodes.append(NodeWithScore(node=neighbour))

		for scored_node in content_nodes:
			node_date = scored_node.node.metadata[LOG_DATE_NAME]
			# Keep the previous / current / next insertion order so that the stable sort
			# below preserves it for nodes sharing the same datetime.
			_append_neighbour(scored_node.node.prev_node, node_date)
			final_nodes.append(scored_node)
			_append_neighbour(scored_node.node.next_node, node_date)

		return self.sort_retrieved_nodes(memory_nodes=final_nodes)

	# Abstract synchronous retrieval entry point; concrete retrievers override it.
	# NOTE(review): the parameter semantics are not defined in this base class -- presumably
	# `item_to_be_retrieved` is the query text, `memory_id` selects the memory to search and
	# `start_date` / `end_date` bound the log dates; confirm against the subclasses.
	@dispatcher.span
	@abstractmethod
	def retrieve(
		self,
		item_to_be_retrieved: str,
		memory_id: str,
		start_date: str = None,
		end_date: str = None,
		**kwargs: Any,
	) -> List[NodeWithScore]:
		r"""
		The docstring of this Method will be used as the tool description of the corresponding retriever tool.
		"""

	# Asynchronous counterpart of `retrieve`, with the same signature; concrete retrievers override it.
	@dispatcher.span
	@abstractmethod
	async def aretrieve(
		self,
		item_to_be_retrieved: str,
		memory_id: str,
		start_date: str = None,
		end_date: str = None,
		**kwargs: Any,
	) -> List[NodeWithScore]:
		r"""
		The docstring of this Method will be used as the tool description of the corresponding retriever tool.
		"""

labridge.func_modules.memory.base.LogBaseRetriever.aretrieve(item_to_be_retrieved, memory_id, start_date=None, end_date=None, **kwargs) abstractmethod async

The docstring of this Method will be used as the tool description of the corresponding retriever tool.

Source code in labridge\func_modules\memory\base.py
193
194
195
196
197
198
199
200
201
202
203
204
205
# Abstract asynchronous retrieval entry point; concrete retrievers override it.
# NOTE(review): the parameter semantics are not defined here -- presumably
# `item_to_be_retrieved` is the query text, `memory_id` selects the memory to search and
# `start_date` / `end_date` bound the log dates; confirm against the subclasses.
@dispatcher.span
@abstractmethod
async def aretrieve(
	self,
	item_to_be_retrieved: str,
	memory_id: str,
	start_date: str = None,
	end_date: str = None,
	**kwargs: Any,
) -> List[NodeWithScore]:
	r"""
	The docstring of this Method will be used as the tool description of the corresponding retriever tool.
	"""

labridge.func_modules.memory.base.LogBaseRetriever.get_date_filter(date_list)

Return the MetadataFilter that filters nodes with dates in the date_list.

PARAMETER DESCRIPTION
date_list

The candidate date strings.

TYPE: List[str]

RETURNS DESCRIPTION
MetadataFilter

The date filter.

TYPE: MetadataFilter

Source code in labridge\func_modules\memory\base.py
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def get_date_filter(self, date_list: List[str]) -> MetadataFilter:
	r"""
	Build the metadata filter used to restrict nodes to the given dates.

	The filter is keyed on `LOG_DATE_NAME` and uses the `ANY` operator with `date_list` as its value.

	Args:
		date_list (List[str]): The candidate date strings.

	Returns:
		MetadataFilter: The date filter.
	"""
	return MetadataFilter(key=LOG_DATE_NAME, value=date_list, operator=FilterOperator.ANY)

labridge.func_modules.memory.base.LogBaseRetriever.get_memory_vector_index() abstractmethod

Get the vector index

Source code in labridge\func_modules\memory\base.py
80
81
82
# Abstract hook: the subclass supplies the vector store index of its memory.
@abstractmethod
def get_memory_vector_index(self) -> VectorStoreIndex:
	r"""
	Get the vector index.

	Returns:
		VectorStoreIndex: The vector index provided by the subclass implementation.
	"""

labridge.func_modules.memory.base.LogBaseRetriever.get_memory_vector_retriever() abstractmethod

Get the vector index retriever from the memory

Source code in labridge\func_modules\memory\base.py
76
77
78
# Abstract hook: the subclass supplies the retriever built on its memory.
@abstractmethod
def get_memory_vector_retriever(self) -> VectorIndexRetriever:
	r"""
	Get the vector index retriever from the memory.

	Returns:
		VectorIndexRetriever: The retriever provided by the subclass implementation.
	"""

labridge.func_modules.memory.base.LogBaseRetriever.retrieve(item_to_be_retrieved, memory_id, start_date=None, end_date=None, **kwargs) abstractmethod

The docstring of this Method will be used as the tool description of the corresponding retriever tool.

Source code in labridge\func_modules\memory\base.py
179
180
181
182
183
184
185
186
187
188
189
190
191
# Abstract synchronous retrieval entry point; concrete retrievers override it.
# NOTE(review): the parameter semantics are not defined here -- presumably
# `item_to_be_retrieved` is the query text, `memory_id` selects the memory to search and
# `start_date` / `end_date` bound the log dates; confirm against the subclasses.
@dispatcher.span
@abstractmethod
def retrieve(
	self,
	item_to_be_retrieved: str,
	memory_id: str,
	start_date: str = None,
	end_date: str = None,
	**kwargs: Any,
) -> List[NodeWithScore]:
	r"""
	The docstring of this Method will be used as the tool description of the corresponding retriever tool.
	"""

labridge.func_modules.memory.base.LogBaseRetriever.sort_retrieved_nodes(memory_nodes, descending=False)

Sort the retrieved nodes according to their datetime.

PARAMETER DESCRIPTION
memory_nodes

The retrieved nodes.

TYPE: List[NodeWithScore]

descending

Sort in descending order. Defaults to False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
List[NodeWithScore]

List[NodeWithScore]: The sorted nodes.

Source code in labridge\func_modules\memory\base.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def sort_retrieved_nodes(
	self,
	memory_nodes: List[NodeWithScore],
	descending: bool = False,
) -> List[NodeWithScore]:
	r"""
	Sort the retrieved nodes according to their datetime.

	The datetime of a node is built by `str_to_datetime` from the first entry of its
	`LOG_DATE_NAME` metadata and the first entry of its `LOG_TIME_NAME` metadata.
	The sort is stable: nodes with equal datetimes keep their relative input order.

	Args:
		memory_nodes (List[NodeWithScore]): The retrieved nodes.
		descending (bool): Sort in descending order. Defaults to False.

	Returns:
		List[NodeWithScore]: A new list holding the sorted nodes (empty if the input is empty).
	"""
	if not memory_nodes:
		return []

	def _node_datetime(scored_node: NodeWithScore):
		# NOTE(review): the metadata values are indexed with [0], so they are presumably
		# lists of strings -- confirm against the code that writes the log nodes.
		metadata = scored_node.node.metadata
		return str_to_datetime(
			date_str=metadata[LOG_DATE_NAME][0],
			time_str=metadata[LOG_TIME_NAME][0],
		)

	# `sorted` returns a real list, matching the declared return type
	# (unpacking via `zip(*...)` would hand back a tuple instead).
	return sorted(memory_nodes, key=_node_datetime, reverse=descending)