Add an option to log result from the Agent (#23454)
This commit is contained in:
parent
f69589d1bc
commit
40ed18ae15
|
@ -207,6 +207,7 @@ class Agent:
|
|||
self.chat_prompt_template = CHAT_PROMPT_TEMPLATE if chat_prompt_template is None else chat_prompt_template
|
||||
self.run_prompt_template = RUN_PROMPT_TEMPLATE if run_prompt_template is None else run_prompt_template
|
||||
self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
|
||||
self.log = print
|
||||
if additional_tools is not None:
|
||||
if isinstance(additional_tools, (list, tuple)):
|
||||
additional_tools = {t.name: t for t in additional_tools}
|
||||
|
@ -244,6 +245,15 @@ class Agent:
|
|||
prompt = prompt.replace("<<prompt>>", task)
|
||||
return prompt
|
||||
|
||||
def set_stream(self, streamer):
|
||||
"""
|
||||
Set the function use to stream results (which is `print` by default).
|
||||
|
||||
Args:
|
||||
streamer (`callable`): The function to call when streaming results from the LLM.
|
||||
"""
|
||||
self.log = streamer
|
||||
|
||||
def chat(self, task, *, return_code=False, remote=False, **kwargs):
|
||||
"""
|
||||
Sends a new request to the agent in a chat. Will use the previous ones in its history.
|
||||
|
@ -273,12 +283,12 @@ class Agent:
|
|||
self.chat_history = prompt + result.strip() + "\n"
|
||||
explanation, code = clean_code_for_chat(result)
|
||||
|
||||
print(f"==Explanation from the agent==\n{explanation}")
|
||||
self.log(f"==Explanation from the agent==\n{explanation}")
|
||||
|
||||
if code is not None:
|
||||
print(f"\n\n==Code generated by the agent==\n{code}")
|
||||
self.log(f"\n\n==Code generated by the agent==\n{code}")
|
||||
if not return_code:
|
||||
print("\n\n==Result==")
|
||||
self.log("\n\n==Result==")
|
||||
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
|
||||
self.chat_state.update(kwargs)
|
||||
return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
|
||||
|
@ -320,11 +330,11 @@ class Agent:
|
|||
result = self.generate_one(prompt, stop=["Task:"])
|
||||
explanation, code = clean_code_for_run(result)
|
||||
|
||||
print(f"==Explanation from the agent==\n{explanation}")
|
||||
self.log(f"==Explanation from the agent==\n{explanation}")
|
||||
|
||||
print(f"\n\n==Code generated by the agent==\n{code}")
|
||||
self.log(f"\n\n==Code generated by the agent==\n{code}")
|
||||
if not return_code:
|
||||
print("\n\n==Result==")
|
||||
self.log("\n\n==Result==")
|
||||
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
|
||||
return evaluate(code, self.cached_tools, state=kwargs.copy())
|
||||
else:
|
||||
|
@ -487,7 +497,7 @@ class HfAgent(Agent):
|
|||
|
||||
response = requests.post(self.url_endpoint, json=inputs, headers=headers)
|
||||
if response.status_code == 429:
|
||||
print("Getting rate-limited, waiting a tiny bit before trying again.")
|
||||
logger.info("Getting rate-limited, waiting a tiny bit before trying again.")
|
||||
time.sleep(1)
|
||||
return self._generate_one(prompt)
|
||||
elif response.status_code != 200:
|
||||
|
|
Loading…
Reference in New Issue