
Efficient Learned Tool Triggers

⬅️ [Geocache app](<./Geocache app.md>) | ⬆️ [Ideas](<./README.md>) | [Disentangled Representation Learning](<./Disentangled Representation Learning.md>) ➡️


Andy's work has a potential failure mode: a tool may never be invoked because the AI has no idea that it can or should use it. Conventional tool-calling agents avoid this by keeping every tool in context, so the agent can see that a tool exists and consider using it. I assume these agents are also trained to attend more closely to the tool blocks.

In Andy's method, the database of tools is potentially unbounded, so we obviously cannot keep every tool, or a list of the contexts in which to use each one, in memory. Instead, we have to rely on the AI somehow knowing that a tool might help in the current situation.

I think it would help for the model to know a priori that a relevant tool may exist. The basic way to do this would be to learn a small head that reads the activations of some token at some layer of the transformer and maps them into the tool embedding space.

Question: How do we represent the "no tool" situation during training?

We do not backpropagate into the LLM. This keeps the head separate: a new head can be trained for each LLM to map into the same tool space without fine-tuning that LLM.
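Here is a minimal sketch of such a head. It assumes a plain linear map trained with mean-squared error. Both choices are assumptions, since the notes do not fix the head's architecture or loss. `LinearToolHead`, `d_model` and `d_tool` are hypothetical names. Activations are copied out of the LLM as plain floats, so each update touches only the head's own parameters, which is what keeps the head separable from the LLM.

```python
import random
from typing import List


class LinearToolHead:
    """Hypothetical linear head mapping a d_model activation to a d_tool tool embedding."""

    def __init__(self, d_model: int, d_tool: int, seed: int = 0):
        rng = random.Random(seed)
        # Small random init; these are the only trainable parameters.
        self.W = [[rng.gauss(0.0, 0.02) for _ in range(d_model)] for _ in range(d_tool)]
        self.b = [0.0] * d_tool

    def __call__(self, act: List[float]) -> List[float]:
        return [sum(w * a for w, a in zip(row, act)) + bias for row, bias in zip(self.W, self.b)]

    def sgd_step(self, act: List[float], target: List[float], lr: float = 0.1) -> float:
        """One SGD step on L = 0.5 * ||head(act) - target||^2; returns the pre-update loss.
        `act` is a detached copy of an LLM activation, so nothing flows back into the LLM."""
        err = [p - t for p, t in zip(self(act), target)]  # dL/dpred, computed before updating
        for i, e in enumerate(err):
            for j, a in enumerate(act):
                self.W[i][j] -= lr * e * a
            self.b[i] -= lr * e
        return 0.5 * sum(e * e for e in err)
```

Swapping the LLM only changes `d_model`. The target tool space, and therefore `d_tool`, stays the same.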

We can train the head on automatically generated data. Feed in a specific tool schema and have an LLM generate a prompt that uses the tool (and potentially a specific tool invocation). Then let the LLM run its thinking process and treat the end of each sentence as a training sample, training the head to map that activation to the tool embedding that was used to generate the prompt. Repeat.
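The sentence-end sampling could look roughly like the sketch below. It assumes the thinking trajectory is available as a list of decoded token strings with one captured activation vector per token. It uses a naive trailing-punctuation check as the sentence boundary. `sentence_end_samples` is a hypothetical helper, and a real pipeline would want a sturdier sentence splitter.

```python
from typing import List, Sequence, Tuple

# Naive sentence boundary (assumption): a token whose text ends in one of these characters.
SENTENCE_END = (".", "?", "!")


def sentence_end_samples(
    tokens: Sequence[str],
    activations: Sequence[List[float]],
    tool_embedding: List[float],
) -> List[Tuple[List[float], List[float]]]:
    """Pair the activation at each sentence-final token with the tool embedding
    that was used to generate the prompt (hypothetical helper)."""
    samples = []
    for tok, act in zip(tokens, activations):
        if tok.strip().endswith(SENTENCE_END):
            samples.append((act, tool_embedding))
    return samples
```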

Again, how do we represent the "no tool" situation? Can we generate adversarial data for tools that don't exist? But what if somebody later wanted to add a tool in a region of the space where there currently isn't one? I feel like there simply is no "no tool" situation, since the space of possible tools is unbounded.

Process

Generate S sample prompts that require a tool call:
1. Select a tool invocation from the dataset of tools
2. Prompt a large LLM to generate a prompt that requires that specific tool invocation
3. Store the prompt and tool invocation/tool embedding
4. Repeat steps 1-3 until S samples have been produced
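The steps above could be sketched as the loop below, with the large LLM abstracted as a callable. The tool record layout (an `invocation` plus an `embedding`), `GeneratedSample` and `write_prompt` are hypothetical. In practice `write_prompt` would wrap a call to whatever large model is used.

```python
import random
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class GeneratedSample:
    prompt: str
    invocation: Dict
    tool_embedding: List[float]


def generate_samples(tools: List[Dict], write_prompt: Callable[[Dict], str],
                     num_samples: int, seed: int = 0) -> List[GeneratedSample]:
    """Produce num_samples (S) prompt/invocation pairs following steps 1-4."""
    rng = random.Random(seed)
    samples: List[GeneratedSample] = []
    while len(samples) < num_samples:  # 4. repeat until S samples exist
        tool = rng.choice(tools)  # 1. select a tool invocation
        prompt = write_prompt(tool["invocation"])  # 2. large LLM writes a prompt needing it
        samples.append(GeneratedSample(prompt, tool["invocation"], tool["embedding"]))  # 3. store
    return samples
```

A stub such as `lambda inv: f"Please help me with {inv['name']}"` is enough to exercise the loop before wiring in a real model.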

It may be worth checking whether large LLMs can generate simulated back-and-forth conversations. If that doesn't work well, just use single-turn prompts and hope the head generalizes.

Train the head to project into the tool space:

Because the head is very small relative to the transformer, the layer choice can be tuned cheaply. We should train multiple heads in parallel on different layers and keep the one that does best on validation.
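The following is a sketch of per-layer heads trained in lockstep. It rests on three assumptions: each sample carries activations for every candidate layer, captured in one frozen forward pass; each head is linear; and each head is trained with SGD on squared error. `train_heads_per_layer`, `val_mse` and `pick_best_layer` are hypothetical names.

```python
import random
from typing import Dict, List, Sequence, Tuple

Head = Tuple[List[List[float]], List[float]]  # (W, b) of a linear head
# One sample: ({layer: activation}, tool_embedding), all layers from one forward pass.
Sample = Tuple[Dict[int, List[float]], List[float]]


def predict(head: Head, act: List[float]) -> List[float]:
    W, b = head
    return [sum(w * x for w, x in zip(row, act)) + bi for row, bi in zip(W, b)]


def train_heads_per_layer(samples: Sequence[Sample], layers: Sequence[int], d_model: int,
                          d_tool: int, epochs: int = 100, lr: float = 0.1,
                          seed: int = 0) -> Dict[int, Head]:
    """Train one linear head per layer on the same shuffled sample stream."""
    rng = random.Random(seed)
    heads = {layer: ([[0.0] * d_model for _ in range(d_tool)], [0.0] * d_tool) for layer in layers}
    for _ in range(epochs):
        for acts, target in rng.sample(list(samples), len(samples)):
            for layer, head in heads.items():
                W, b = head
                x = acts[layer]
                err = [p - t for p, t in zip(predict(head, x), target)]
                for i, e in enumerate(err):
                    for j, xj in enumerate(x):
                        W[i][j] -= lr * e * xj
                    b[i] -= lr * e
    return heads


def val_mse(head: Head, samples: Sequence[Sample], layer: int) -> float:
    """Mean over samples of the summed squared error for one layer's head."""
    total = 0.0
    for acts, target in samples:
        total += sum((p - t) ** 2 for p, t in zip(predict(head, acts[layer]), target))
    return total / len(samples)


def pick_best_layer(heads: Dict[int, Head], val_samples: Sequence[Sample]):
    """Return (best_layer, {layer: validation MSE})."""
    scores = {layer: val_mse(head, val_samples, layer) for layer, head in heads.items()}
    return min(scores, key=scores.get), scores
```

The heads all train in a single loop over the same samples, so the frozen LLM runs only once per sample no matter how many layers are probed.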

  1. Sample a batch of B generated prompts
  2. Run the samples through the LLM to generate thinking trajectories, storing the activations along the way.
  3. Put the activations and corresponding tool embeddings in a buffer
  4. Sample randomly from this buffer to train the head(s), so that consecutive training samples mix different tool embeddings instead of repeating the same one many times in a row. Since the LLM is frozen, samples never become invalid, so we could keep them forever if we wanted. Still, a rolling buffer is probably the way to go.
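Here is a minimal rolling buffer, assuming activations arrive as plain float lists. `ActivationBuffer` and its methods are hypothetical. `collections.deque(maxlen=...)` handles eviction, and `random.Random.sample` draws a batch without replacement. As a result, a batch mixes tools even when whole trajectories for one tool are added at a time.

```python
import random
from collections import deque
from typing import List, Tuple


class ActivationBuffer:
    """Hypothetical rolling buffer of (activation, tool_embedding) pairs.
    Oldest pairs are evicted once `capacity` is reached. They never go stale
    because the LLM is frozen, so eviction only bounds memory."""

    def __init__(self, capacity: int, seed: int = 0):
        self.items: deque = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def add_trajectory(self, activations: List[List[float]], tool_embedding: List[float]) -> None:
        for act in activations:
            self.items.append((act, tool_embedding))

    def sample(self, batch_size: int) -> List[Tuple[List[float], List[float]]]:
        # Uniform sampling without replacement over everything currently buffered.
        k = min(batch_size, len(self.items))
        return self.rng.sample(list(self.items), k)
```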

Validation:
Use human-generated conversations with known best tool invocations as the validation set, so that any bias toward LLM-generated prompts becomes visible. Also use multi-turn conversations to evaluate whether performance drops over back-and-forth conversations that are not represented in the generated data.
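One way to score the validation set is top-1 retrieval accuracy: find the known tool embedding nearest to the head's output by Euclidean distance, and break the results down by the number of turns in the conversation. A drop on multi-turn data then shows up directly. Both the metric and the helper names are assumptions.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def nearest_tool(pred: List[float], tool_table: Dict[str, List[float]]) -> str:
    """Name of the known tool whose embedding is closest (Euclidean) to pred."""
    return min(tool_table, key=lambda name: math.dist(pred, tool_table[name]))


def accuracy_by_turns(examples: Sequence[Tuple[List[float], str, int]],
                      tool_table: Dict[str, List[float]]) -> Dict[int, float]:
    """examples: (head_output, gold_tool_name, num_turns) triples from the
    human-written validation set. Returns top-1 accuracy per turn count."""
    hits: Dict[int, int] = defaultdict(int)
    totals: Dict[int, int] = defaultdict(int)
    for pred, gold, turns in examples:
        totals[turns] += 1
        hits[turns] += int(nearest_tool(pred, tool_table) == gold)
    return {t: hits[t] / totals[t] for t in sorted(totals)}
```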

Inference time:
Simplest method:
After each sentence, run the activations through the head. If the output is within some distance D of a known tool embedding, either bias sampling toward the dispatch tokens or simply inject the tool's summary in a block like:

TOOL INVOCATION HELPER  
Suggested tool: Search internet  
END TOOL INVOCATION HELPER  
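The simplest trigger could be sketched as follows. It assumes the head's output and the known tool embeddings are plain float lists and that distance means Euclidean distance (`math.dist`). `tool_hint` and `max_distance` (the D above) are hypothetical names. A `None` result is this sketch's stand-in for the "no tool" case.

```python
import math
from typing import Dict, List, Optional


def tool_hint(head_output: List[float], tool_table: Dict[str, List[float]],
              max_distance: float) -> Optional[str]:
    """Return a helper block naming the closest known tool if it lies within
    max_distance (D) of the head's output, otherwise None (no hint injected)."""
    name, embedding = min(tool_table.items(), key=lambda kv: math.dist(head_output, kv[1]))
    if math.dist(head_output, embedding) > max_distance:
        return None
    return (
        "TOOL INVOCATION HELPER\n"
        f"Suggested tool: {name}\n"
        "END TOOL INVOCATION HELPER"
    )
```

At each sentence end, the decoder would call `tool_hint` on the head's output and append any non-`None` result to the context before continuing generation. The alternative of biasing the dispatch tokens would reuse the same distance check.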
