Add an option to log result from the Agent (#23454)

This commit is contained in:
Sylvain Gugger 2023-05-18 14:06:49 -04:00 committed by GitHub
parent f69589d1bc
commit 40ed18ae15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 7 deletions

View File

@ -207,6 +207,7 @@ class Agent:
self.chat_prompt_template = CHAT_PROMPT_TEMPLATE if chat_prompt_template is None else chat_prompt_template
self.run_prompt_template = RUN_PROMPT_TEMPLATE if run_prompt_template is None else run_prompt_template
self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
self.log = print
if additional_tools is not None:
if isinstance(additional_tools, (list, tuple)):
additional_tools = {t.name: t for t in additional_tools}
@ -244,6 +245,15 @@ class Agent:
prompt = prompt.replace("<<prompt>>", task)
return prompt
def set_stream(self, streamer):
"""
Set the function use to stream results (which is `print` by default).
Args:
streamer (`callable`): The function to call when streaming results from the LLM.
"""
self.log = streamer
def chat(self, task, *, return_code=False, remote=False, **kwargs):
"""
Sends a new request to the agent in a chat. Will use the previous ones in its history.
@ -273,12 +283,12 @@ class Agent:
self.chat_history = prompt + result.strip() + "\n"
explanation, code = clean_code_for_chat(result)
print(f"==Explanation from the agent==\n{explanation}")
self.log(f"==Explanation from the agent==\n{explanation}")
if code is not None:
print(f"\n\n==Code generated by the agent==\n{code}")
self.log(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
print("\n\n==Result==")
self.log("\n\n==Result==")
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
self.chat_state.update(kwargs)
return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
@ -320,11 +330,11 @@ class Agent:
result = self.generate_one(prompt, stop=["Task:"])
explanation, code = clean_code_for_run(result)
print(f"==Explanation from the agent==\n{explanation}")
self.log(f"==Explanation from the agent==\n{explanation}")
print(f"\n\n==Code generated by the agent==\n{code}")
self.log(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
print("\n\n==Result==")
self.log("\n\n==Result==")
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
return evaluate(code, self.cached_tools, state=kwargs.copy())
else:
@ -487,7 +497,7 @@ class HfAgent(Agent):
response = requests.post(self.url_endpoint, json=inputs, headers=headers)
if response.status_code == 429:
print("Getting rate-limited, waiting a tiny bit before trying again.")
logger.info("Getting rate-limited, waiting a tiny bit before trying again.")
time.sleep(1)
return self._generate_one(prompt)
elif response.status_code != 200: