
Spring AI Technical Notes: Function Call
[Figure: execution flow of Tools (Function Calling), from the Alibaba Cloud Spring AI Alibaba official site]
The actual logic lives inside the ChatModel implementation. Taking the call method as an example:
```java
// Source: /Users/xuanmiss/.m2/repository/com/alibaba/cloud/ai/spring-ai-alibaba-core/1.0.0-M5.1/spring-ai-alibaba-core-1.0.0-M5.1.jar!/com/alibaba/cloud/ai/dashscope/chat/DashScopeChatModel.class
public ChatResponse call(Prompt prompt) {
    ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(DashScopeApiConstants.PROVIDER_NAME)
            .requestOptions((ChatOptions) (prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions))
            .build();
    ChatResponse chatResponse = (ChatResponse) ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> {
                return observationContext;
            }, this.observationRegistry)
            .observe(() -> {
                // Build the DashScope request and send it, with retries
                DashScopeApi.ChatCompletionRequest request = this.createRequest(prompt, false);
                ResponseEntity<DashScopeApi.ChatCompletion> completionEntity = (ResponseEntity) this.retryTemplate.execute((ctx) -> {
                    return this.dashscopeApi.chatCompletionEntity(request);
                });
                DashScopeApi.ChatCompletion chatCompletion = (DashScopeApi.ChatCompletion) completionEntity.getBody();
                if (chatCompletion == null) {
                    logger.warn("No chat completion returned for prompt: {}", prompt);
                    return new ChatResponse(List.of());
                } else {
                    // Turn each returned choice into a Generation, keeping id/role/finishReason as metadata
                    List<DashScopeApi.ChatCompletionOutput.Choice> choices = chatCompletion.output().choices();
                    List<Generation> generations = choices.stream().map((choice) -> {
                        Map<String, Object> metadata = Map.of(
                                "id", chatCompletion.requestId(),
                                "role", choice.message().role() != null ? choice.message().role().name() : "",
                                "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
                        return buildGeneration(choice, metadata);
                    }).toList();
                    ChatResponse response = new ChatResponse(generations, this.from((DashScopeApi.ChatCompletion) completionEntity.getBody()));
                    observationContext.setResponse(response);
                    return response;
                }
            });
    // The model requested tools: run them, then call the model again with the extended conversation
    if (this.isToolCall(chatResponse, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
        List<Message> toolCallConversation = this.handleToolCalls(prompt, chatResponse);
        return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
    } else {
        return chatResponse;
    }
}
```
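Here the model's response is checked. If it contains something like {"message":{"tool_calls" and the finish reason is TOOL_CALLS or STOP, handleToolCalls executes the function/tool calls. call() then invokes itself recursively with the extended conversation, so the model can see the tool results and produce its final answer.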
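Stripped of observation and retry handling, the control flow of call() is a loop. Send the conversation to the model. If it answers with a tool call, run the tool, append both the assistant message and the tool result, and ask again. The sketch below models this with plain JDK types. Msg, FakeModel and ToolLoopSketch are hypothetical names standing in for Spring AI's Message, ChatModel and the recursive call(), and the "model" is a scripted stub, not a real LLM.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, simplified message: role is "user", "assistant" or "tool".
// toolName/toolArgs are only set when an assistant message requests a tool call.
record Msg(String role, String text, String toolName, String toolArgs) {
    static Msg user(String text) { return new Msg("user", text, null, null); }
    static Msg assistant(String text) { return new Msg("assistant", text, null, null); }
    static Msg toolCall(String name, String args) { return new Msg("assistant", "", name, args); }
    static Msg toolResult(String text) { return new Msg("tool", text, null, null); }
    boolean isToolCall() { return toolName != null; }
}

// Scripted stand-in for a chat model: first asks for a tool, then answers using the tool output.
class FakeModel {
    Msg reply(List<Msg> conversation) {
        Msg last = conversation.get(conversation.size() - 1);
        if (last.role().equals("tool")) {
            return Msg.assistant("Answer based on: " + last.text());
        }
        return Msg.toolCall("getOrderFunction", "{\"orderId\":\"2001\"}");
    }
}

class ToolLoopSketch {
    // Same shape as DashScopeChatModel.call(): ask the model; if it requested a tool,
    // run it, extend the conversation and recurse; otherwise return the answer.
    static Msg call(FakeModel model, List<Msg> conversation, Map<String, Function<String, String>> tools) {
        Msg response = model.reply(conversation);
        if (!response.isToolCall()) {
            return response; // plain answer: stop recursing
        }
        String result = tools.get(response.toolName()).apply(response.toolArgs());
        List<Msg> next = new ArrayList<>(conversation);
        next.add(response);               // assistant message carrying the tool call
        next.add(Msg.toolResult(result)); // tool output the model will read next round
        return call(model, next, tools);
    }

    public static void main(String[] args) {
        Map<String, Function<String, String>> tools = Map.of("getOrderFunction", json -> "order 2001 found");
        Msg answer = call(new FakeModel(), List.of(Msg.user("Look up order 2001")), tools);
        System.out.println(answer.text()); // prints: Answer based on: order 2001 found
    }
}
```

Nothing in this sketch bounds the recursion. A model that keeps requesting tools would recurse indefinitely, so a production loop may want a maximum number of rounds.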
Next, let's look at the execution code: how the tools we registered are found and how they are invoked.
```java
// Source: /Users/xuanmiss/.m2/repository/org/springframework/ai/spring-ai-core/1.0.0-M5/spring-ai-core-1.0.0-M5.jar!/org/springframework/ai/chat/model/AbstractToolCallSupport.class
protected List<Message> handleToolCalls(Prompt prompt, ChatResponse response) {
    // Take the first generation that actually contains tool calls
    Optional<Generation> toolCallGeneration = response.getResults().stream().filter((g) -> {
        return !CollectionUtils.isEmpty(g.getOutput().getToolCalls());
    }).findFirst();
    if (toolCallGeneration.isEmpty()) {
        throw new IllegalStateException("No tool call generation found in the response!");
    } else {
        AssistantMessage assistantMessage = ((Generation) toolCallGeneration.get()).getOutput();
        Map<String, Object> toolContextMap = Map.of();
        ChatOptions var7 = prompt.getOptions();
        // If the options carry a tool context, copy it and add the conversation so far as TOOL_CALL_HISTORY
        if (var7 instanceof FunctionCallingOptions) {
            FunctionCallingOptions functionCallOptions = (FunctionCallingOptions) var7;
            if (!CollectionUtils.isEmpty(functionCallOptions.getToolContext())) {
                toolContextMap = new HashMap(functionCallOptions.getToolContext());
                List<Message> toolCallHistory = new ArrayList(prompt.copy().getInstructions());
                toolCallHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), assistantMessage.getToolCalls()));
                ((Map) toolContextMap).put("TOOL_CALL_HISTORY", toolCallHistory);
            }
        }
        // Run the tools, then build the new conversation: previous messages + assistant message + tool results
        ToolResponseMessage toolMessageResponse = this.executeFunctions(assistantMessage, new ToolContext((Map) toolContextMap));
        List<Message> toolConversationHistory = this.buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse);
        return toolConversationHistory;
    }
}

// ······

protected ToolResponseMessage executeFunctions(AssistantMessage assistantMessage, ToolContext toolContext) {
    List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList();
    Iterator var4 = assistantMessage.getToolCalls().iterator();
    while (var4.hasNext()) {
        AssistantMessage.ToolCall toolCall = (AssistantMessage.ToolCall) var4.next();
        String functionName = toolCall.name();
        String functionArguments = toolCall.arguments();
        // Look up the registered callback by the function name the model asked for
        if (!this.functionCallbackRegister.containsKey(functionName)) {
            throw new IllegalStateException("No function callback found for function name: " + functionName);
        }
        String functionResponse = ((FunctionCallback) this.functionCallbackRegister.get(functionName)).call(functionArguments, toolContext);
        toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName, functionResponse));
    }
    return new ToolResponseMessage(toolResponses, Map.of());
}
```
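handleToolCalls is where the tools are invoked. It runs them through executeFunctions(assistantMessage, new ToolContext((Map)toolContextMap)). It then assembles the complete message list with buildToolCallConversation.

executeFunctions is a method of the same class. For each tool call it:
- reads the function name,
- fetches the matching FunctionCallback from functionCallbackRegister,
- executes ((FunctionCallback)this.functionCallbackRegister.get(functionName)).call(functionArguments, toolContext).

So the actual work is delegated to the FunctionCallback interface (FunctionCallback.java). Below we look at this interface and its two concrete implementations, which correspond to the two ways of registering tools.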
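The lookup in executeFunctions is a name-keyed map of callbacks that fails fast on an unknown name. Below is a minimal JDK-only sketch of that dispatch. ToolCall, ToolResponse and ToolRegistry are hypothetical stand-ins for AssistantMessage.ToolCall, ToolResponseMessage.ToolResponse and functionCallbackRegister. A plain Function<String, String> replaces FunctionCallback, and the tool context is omitted.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-ins for the Spring AI tool call / tool response types.
record ToolCall(String id, String name, String arguments) {}
record ToolResponse(String id, String name, String output) {}

class ToolRegistry {
    // function name -> callback, playing the role of functionCallbackRegister
    private final Map<String, Function<String, String>> callbacks = new LinkedHashMap<>();

    void register(String name, Function<String, String> callback) {
        callbacks.put(name, callback);
    }

    // Same shape as executeFunctions: resolve each call by name, fail on unknown names,
    // and keep the call id so each result can be matched to the request that produced it.
    List<ToolResponse> execute(List<ToolCall> toolCalls) {
        List<ToolResponse> responses = new ArrayList<>();
        for (ToolCall call : toolCalls) {
            Function<String, String> callback = callbacks.get(call.name());
            if (callback == null) {
                throw new IllegalStateException("No function callback found for function name: " + call.name());
            }
            responses.add(new ToolResponse(call.id(), call.name(), callback.apply(call.arguments())));
        }
        return responses;
    }
}
```

This also explains a common runtime error. If the model asks for a tool name that was never registered (for example, a typo between the bean name and the name passed to the client), execution stops with IllegalStateException instead of returning a partial answer.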
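The relevant part of the FunctionCallback interface source is shown below. The default call(String, ToolContext) throws if a non-empty tool context is passed, and otherwise simply delegates to call(String):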
```java
String call(String functionInput);

default String call(String functionInput, ToolContext tooContext) {
    if (tooContext != null && !tooContext.getContext().isEmpty()) {
        throw new UnsupportedOperationException("Function context is not supported!");
    } else {
        return this.call(functionInput);
    }
}
```
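For implementers, this default means a callback that never needs the context only implements the one-argument method. A callback that does need it must override the two-argument one. The JDK-only sketch below copies the same default-method pattern into a hypothetical SimpleCallback interface, with a Map standing in for ToolContext. UpperCaseCallback and GreetingCallback are illustrative names only.

```java
import java.util.Map;

// Hypothetical copy of the FunctionCallback default-method pattern (Map replaces ToolContext).
interface SimpleCallback {
    String call(String input);

    // Default: reject a non-empty context, otherwise ignore it and delegate.
    default String call(String input, Map<String, Object> context) {
        if (context != null && !context.isEmpty()) {
            throw new UnsupportedOperationException("Function context is not supported!");
        }
        return call(input);
    }
}

// Does not use the context, so it inherits the default two-argument method.
class UpperCaseCallback implements SimpleCallback {
    public String call(String input) { return input.toUpperCase(); }
}

// Needs the context, so it overrides the two-argument method.
class GreetingCallback implements SimpleCallback {
    public String call(String input) { return call(input, Map.of()); }

    @Override
    public String call(String input, Map<String, Object> context) {
        Object user = context == null ? null : context.get("user");
        return "Hello " + (user != null ? user : input);
    }
}
```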
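Now let's look at the implementations provided out of the box and their logic. There are three implementation classes: AbstractFunctionCallback, FunctionDefinition and MethodInvokingFunctionCallback. The second, FunctionDefinition, only supplies a definition and contains no actual function logic.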
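The first, AbstractFunctionCallback, is implemented as follows. Execution ends up here whether we define a FunctionCallback ourselves or register a bean that implements the Function interface: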
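```java
public String call(String functionArguments) {
    // Deserialize the JSON arguments into the input type, apply the function
    // with a null ToolContext, then convert the result with responseConverter
    I request = this.fromJson(functionArguments, this.inputType);
    return (String) this.andThen(this.responseConverter).apply(request, (Object) null);
}
```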
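AbstractFunctionCallback is itself a BiFunction of (input, ToolContext). This is why this.andThen(this.responseConverter) can compose "run the tool" with "turn the result into a String" in a single call. The JDK-only sketch below reproduces that decode, apply, convert pipeline with BiFunction.andThen. OrderRequest, PipelineCallback and the hand-written decode method are hypothetical: decode stands in for the Jackson-based fromJson, a Map stands in for ToolContext, and the converter is a toy that just quotes the output.

```java
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical typed input, playing the role of the callback's input type I.
record OrderRequest(String userId, String orderId) {}

class PipelineCallback {
    // The tool itself: (input, context) -> output, like AbstractFunctionCallback's BiFunction.
    static final BiFunction<OrderRequest, Map<String, Object>, String> TOOL =
            (req, ctx) -> "order " + req.orderId() + " of user " + req.userId();

    // Toy stand-in for responseConverter: wraps the output in quotes.
    static final Function<String, String> RESPONSE_CONVERTER = out -> "\"" + out + "\"";

    // Stand-in for fromJson; real code would parse the JSON string with a library such as Jackson.
    static OrderRequest decode(Map<String, String> args) {
        return new OrderRequest(args.get("userId"), args.get("orderId"));
    }

    // Mirrors call(String): decode, then apply the composed BiFunction with a null context.
    static String call(Map<String, String> args) {
        OrderRequest request = decode(args);
        return TOOL.andThen(RESPONSE_CONVERTER).apply(request, null);
    }
}
```

Composing with andThen keeps the tool body free of serialization concerns. The same BiFunction can therefore be reused with a different converter.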
The third, MethodInvokingFunctionCallback, works as follows. Execution ends up here when we take an existing class, such as a service, and register one of its methods as a tool:
```java
public String call(String functionInput, ToolContext toolContext) {
    try {
        // A non-empty ToolContext is only allowed if the target method declares a ToolContext parameter
        if (toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext()) && !this.isToolContextMethod) {
            throw new IllegalArgumentException("Configured method does not accept ToolContext as input parameter!");
        } else {
            // Parse the JSON arguments and match them to method parameters by name;
            // a ToolContext parameter receives the context instead
            Map<String, Object> map = (Map) this.mapper.readValue(functionInput, Map.class);
            Object[] methodArgs = Stream.of(this.method.getParameters()).map((parameter) -> {
                Class<?> type = parameter.getType();
                if (ClassUtils.isAssignable(type, ToolContext.class)) {
                    return toolContext;
                } else {
                    Object rawValue = map.get(parameter.getName());
                    return this.toJavaType(rawValue, type);
                }
            }).toArray();
            // Invoke reflectively: void -> "Done"; Class/record/List/Map -> JSON; otherwise responseConverter
            Object response = ReflectionUtils.invokeMethod(this.method, this.functionObject, methodArgs);
            Class<?> returnType = this.method.getReturnType();
            if (returnType == Void.TYPE) {
                return "Done";
            } else {
                return returnType != Class.class && !returnType.isRecord() && returnType != List.class && returnType != Map.class
                        ? (String) this.responseConverter.apply(response)
                        : ModelOptionsUtils.toJsonString(response);
            }
        }
    } catch (Exception var7) {
        Exception e = var7;
        ReflectionUtils.handleReflectionException(e);
        return null;
    }
}
```
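The core of this method is ordinary Java reflection. It builds an argument array from the method's parameters, injects the context by type, invokes the method, and maps a void return to "Done". The JDK-only sketch below is a simplified version under stated assumptions:
- Ctx is a hypothetical stand-in for ToolContext.
- The arguments arrive already parsed from JSON and already of the right Java types, so there is no toJavaType step.
- Non-void results are simply stringified, not passed through a converter or JSON.

Matching arguments by parameter name only works when the class is compiled with javac's -parameters flag. Without it, Parameter.getName() returns synthetic names such as arg0.

```java
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Map;

// Hypothetical stand-in for Spring AI's ToolContext: a wrapper around a map.
record Ctx(Map<String, Object> values) {}

class ReflectiveInvoker {
    // A parameter takes the context if a Ctx can be assigned to it
    // (same direction as ClassUtils.isAssignable(type, ToolContext.class) above).
    static boolean isCtxParam(Parameter p) {
        return p.getType().isAssignableFrom(Ctx.class);
    }

    // Simplified MethodInvokingFunctionCallback.call: args are assumed to be pre-parsed and correctly typed.
    static String invoke(Object target, Method method, Map<String, Object> args, Ctx ctx) throws Exception {
        Parameter[] params = method.getParameters();
        boolean acceptsCtx = false;
        for (Parameter p : params) {
            acceptsCtx |= isCtxParam(p);
        }
        if (ctx != null && !ctx.values().isEmpty() && !acceptsCtx) {
            throw new IllegalArgumentException("Configured method does not accept the context as input parameter!");
        }
        Object[] values = new Object[params.length];
        for (int i = 0; i < params.length; i++) {
            // Context parameters get the context; others are looked up by parameter name
            // (real names require -parameters, otherwise they are arg0, arg1, ...).
            values[i] = isCtxParam(params[i]) ? ctx : args.get(params[i].getName());
        }
        Object result = method.invoke(target, values);
        // void methods report "Done"; other results are simply stringified in this sketch
        return method.getReturnType() == void.class ? "Done" : String.valueOf(result);
    }
}

// Hypothetical "existing service" whose public methods could be exposed as tools.
class GreetingService {
    public String greet(Ctx ctx) { return "Hello " + ctx.values().get("user"); }
    public void audit(Ctx ctx) { /* side effect only, e.g. logging */ }
    public int count() { return 3; }
}
```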
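I didn't write an example of this style this time; the example from the official documentation below gives a rough idea. Strictly speaking, it exposes the existing MockOrderService.getOrder method by wrapping it in a Function bean (mockOrderService::getOrder). At runtime, therefore, it follows the Function-bean path handled by AbstractFunctionCallback.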
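```java
// 1. The existing MockOrderService
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.springframework.stereotype.Service;

@Service
public class MockOrderService {

    public Response getOrder(Request request) {
        String productName = "Yonex badminton racket";
        return new Response(String.format("The order number of user %s is %s, product purchased: %s",
                request.userId(), request.orderId(), productName));
    }

    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record Request(
            // The @JsonProperty annotations become the function's parameters info, including names and descriptions
            /*
            {
              "orderId": {
                "type": "string",
                "description": "Order number, e.g. 1001***"
              },
              "userId": {
                "type": "string",
                "description": "User ID, e.g. 2001***"
              }
            }
            */
            @JsonProperty(required = true, value = "orderId") @JsonPropertyDescription("Order number, e.g. 1001***") String orderId,
            @JsonProperty(required = true, value = "userId") @JsonPropertyDescription("User ID, e.g. 2001***") String userId) {
    }

    public record Response(String description) {
    }
}

// 2. Register MockOrderService.getOrder as a function-call bean
import java.util.function.Function;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Description;

@Configuration
public class FunctionCallConfiguration {

    @Bean
    @Description("Query order information by user ID and order number") // the function's description
    public Function<MockOrderService.Request, MockOrderService.Response> getOrderFunction(MockOrderService mockOrderService) {
        return mockOrderService::getOrder;
    }
}

// 3. Invoke the function call
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;

DashScopeChatModel dashscopeChatModel = ...;
ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
        .defaultFunctions("getOrderFunction")
        .build();
ChatResponse response = chatClient
        .prompt()
        .user("Please look up my order. User ID is 1001, order number is 2001")
        .call()
        .chatResponse();
String content = response.getResult().getOutput().getText();
logger.info("content: {}", content);
```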