Skip to content

Commit 0d66f5a

Browse files
authored
feat: add react agent example (#393)
1 parent dc3a603 commit 0d66f5a

15 files changed

Lines changed: 841 additions & 0 deletions

File tree

spring-ai-alibaba-agent-example/pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
<modules>
3939
<module>playground-flight-booking</module>
40+
<module>react-agent-example</module>
4041
</modules>
4142

4243
<build>
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
/mvnw text eol=lf
2+
*.cmd text eol=crlf
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
HELP.md
2+
target/
3+
.mvn/wrapper/maven-wrapper.jar
4+
!**/src/main/**/target/
5+
!**/src/test/**/target/
6+
7+
### STS ###
8+
.apt_generated
9+
.classpath
10+
.factorypath
11+
.project
12+
.settings
13+
.springBeans
14+
.sts4-cache
15+
16+
### IntelliJ IDEA ###
17+
.idea
18+
*.iws
19+
*.iml
20+
*.ipr
21+
22+
### NetBeans ###
23+
/nbproject/private/
24+
/nbbuild/
25+
/dist/
26+
/nbdist/
27+
/.nb-gradle/
28+
build/
29+
!**/src/main/**/build/
30+
!**/src/test/**/build/
31+
32+
### VS Code ###
33+
.vscode/
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
3+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
4+
<modelVersion>4.0.0</modelVersion>
5+
<parent>
6+
<groupId>org.springframework.boot</groupId>
7+
<artifactId>spring-boot-starter-parent</artifactId>
8+
<version>3.5.7</version>
9+
<relativePath/> <!-- lookup parent from repository -->
10+
</parent>
11+
<groupId>com.cloud.alibaba.ai.example</groupId>
12+
<artifactId>react-agent-example</artifactId>
13+
<version>0.0.1-SNAPSHOT</version>
14+
<name>planner-agent-example</name>
15+
<description>planner-agent-example</description>
16+
<url/>
17+
<licenses>
18+
<license/>
19+
</licenses>
20+
<developers>
21+
<developer/>
22+
</developers>
23+
<scm>
24+
<connection/>
25+
<developerConnection/>
26+
<tag/>
27+
<url/>
28+
</scm>
29+
<properties>
30+
<java.version>17</java.version>
31+
<saa.agent.version>1.1.0.0-SNAPSHOT</saa.agent.version>
32+
</properties>
33+
<dependencies>
34+
<dependency>
35+
<groupId>org.springframework.boot</groupId>
36+
<artifactId>spring-boot-starter</artifactId>
37+
</dependency>
38+
39+
<dependency>
40+
<groupId>org.springframework.boot</groupId>
41+
<artifactId>spring-boot-starter-test</artifactId>
42+
<scope>test</scope>
43+
</dependency>
44+
45+
<dependency>
46+
<groupId>org.springframework.boot</groupId>
47+
<artifactId>spring-boot-starter-web</artifactId>
48+
</dependency>
49+
50+
<!-- Spring AI Alibaba Agent Framework -->
51+
<dependency>
52+
<groupId>com.alibaba.cloud.ai</groupId>
53+
<artifactId>spring-ai-alibaba-agent-framework</artifactId>
54+
<version>${saa.agent.version}</version>
55+
</dependency>
56+
57+
<!-- DashScope ChatModel -->
58+
<dependency>
59+
<groupId>com.alibaba.cloud.ai</groupId>
60+
<artifactId>spring-ai-alibaba-starter-dashscope</artifactId>
61+
<version>${saa.agent.version}</version>
62+
</dependency>
63+
64+
<dependency>
65+
<groupId>org.springframework.boot</groupId>
66+
<artifactId>spring-boot-starter-thymeleaf</artifactId>
67+
</dependency>
68+
69+
</dependencies>
70+
71+
<build>
72+
<plugins>
73+
<plugin>
74+
<groupId>org.springframework.boot</groupId>
75+
<artifactId>spring-boot-maven-plugin</artifactId>
76+
</plugin>
77+
</plugins>
78+
</build>
79+
80+
</project>
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package com.cloud.alibaba.ai.example;
2+
3+
import org.springframework.boot.SpringApplication;
4+
import org.springframework.boot.autoconfigure.SpringBootApplication;
5+
6+
@SpringBootApplication
7+
public class Application {
8+
9+
public static void main(String[] args) {
10+
SpringApplication.run(Application.class, args);
11+
}
12+
13+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package com.cloud.alibaba.ai.example.config;
2+
3+
import com.alibaba.cloud.ai.graph.agent.ReactAgent;
4+
import com.alibaba.cloud.ai.graph.agent.hook.hip.HumanInTheLoopHook;
5+
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
6+
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
7+
import com.cloud.alibaba.ai.example.interceptor.LogToolInterceptor;
8+
import com.cloud.alibaba.ai.example.tools.FileReadTool;
9+
import com.cloud.alibaba.ai.example.tools.FileWriteTool;
10+
import org.springframework.ai.chat.model.ChatModel;
11+
import org.springframework.context.annotation.Bean;
12+
import org.springframework.context.annotation.Configuration;
13+
14+
@Configuration
15+
public class AgentConfiguration {
16+
17+
private final ChatModel chatModel;
18+
19+
public AgentConfiguration(ChatModel chatModel) {
20+
this.chatModel = chatModel;
21+
}
22+
23+
@Bean
24+
public ReactAgent reactAgent() throws GraphStateException {
25+
return ReactAgent.builder()
26+
.name("agent")
27+
.description("This is a react agent")
28+
.model(chatModel)
29+
.saver(new MemorySaver())
30+
.tools(
31+
new FileReadTool().toolCallback(),
32+
new FileWriteTool().toolCallback()
33+
)
34+
.hooks(HumanInTheLoopHook.builder()
35+
.approvalOn("file_write", "Write File should be approved")
36+
.build())
37+
.interceptors(new LogToolInterceptor())
38+
.build();
39+
}
40+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package com.cloud.alibaba.ai.example.controller;
2+
3+
import com.alibaba.cloud.ai.graph.RunnableConfig;
4+
import com.alibaba.cloud.ai.graph.action.InterruptionMetadata;
5+
import com.alibaba.cloud.ai.graph.agent.ReactAgent;
6+
import org.springframework.stereotype.Controller;
7+
import org.springframework.web.bind.annotation.GetMapping;
8+
import org.springframework.web.bind.annotation.PostMapping;
9+
import org.springframework.web.bind.annotation.RequestBody;
10+
import org.springframework.web.bind.annotation.RequestParam;
11+
import org.springframework.web.bind.annotation.ResponseBody;
12+
13+
import java.util.List;
14+
import java.util.Map;
15+
import java.util.concurrent.ConcurrentHashMap;
16+
17+
@Controller
18+
public class AgentController {
19+
20+
private final ReactAgent reactAgent;
21+
22+
private final Map<String, InterruptionMetadata> map = new ConcurrentHashMap<>();
23+
24+
public AgentController(ReactAgent reactAgent) {
25+
this.reactAgent = reactAgent;
26+
}
27+
28+
@GetMapping("/invoke")
29+
@ResponseBody
30+
public List<InterruptionMetadata.ToolFeedback> invoke(@RequestParam("query") String query,
31+
@RequestParam("threadId") String threadId
32+
) throws Exception {
33+
RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();
34+
InterruptionMetadata metadata = (InterruptionMetadata) reactAgent.invokeAndGetOutput(query, runnableConfig).orElseThrow();
35+
map.put(threadId, metadata);
36+
return metadata.toolFeedbacks();
37+
}
38+
39+
@PostMapping("/feedback")
40+
@ResponseBody
41+
public String feedback(@RequestBody List<Feedback> feedbacks,
42+
@RequestParam("threadId") String threadId
43+
) throws Exception {
44+
InterruptionMetadata metadata = map.get(threadId);
45+
if(metadata == null) {
46+
return "no metadata found";
47+
}
48+
if(metadata.toolFeedbacks().size() != feedbacks.size()) {
49+
return "feedback size not match";
50+
}
51+
52+
InterruptionMetadata.Builder newBuilder = InterruptionMetadata.builder()
53+
.nodeId(metadata.node())
54+
.state(metadata.state());
55+
for (int i = 0; i < feedbacks.size(); i++) {
56+
var toolFeedback = metadata.toolFeedbacks().get(i);
57+
InterruptionMetadata.ToolFeedback.Builder editedFeedbackBuilder = InterruptionMetadata.ToolFeedback
58+
.builder(toolFeedback);
59+
if(feedbacks.get(i).isApproved()) {
60+
editedFeedbackBuilder.result(InterruptionMetadata.ToolFeedback.FeedbackResult.APPROVED);
61+
} else {
62+
editedFeedbackBuilder.result(InterruptionMetadata.ToolFeedback.FeedbackResult.REJECTED)
63+
.description(feedbacks.get(i).feedback());
64+
}
65+
newBuilder.addToolFeedback(editedFeedbackBuilder.build());
66+
}
67+
RunnableConfig resumeRunnableConfig = RunnableConfig.builder().threadId(threadId)
68+
.addMetadata(RunnableConfig.HUMAN_FEEDBACK_METADATA_KEY, newBuilder.build())
69+
.build();
70+
reactAgent.invokeAndGetOutput("", resumeRunnableConfig);
71+
return "success";
72+
}
73+
74+
@GetMapping
75+
public String index() {
76+
return "index";
77+
}
78+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
package com.cloud.alibaba.ai.example.controller;
2+
3+
public record Feedback(boolean isApproved, String feedback) {
4+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package com.cloud.alibaba.ai.example.interceptor;
2+
3+
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolCallHandler;
4+
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolCallRequest;
5+
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolCallResponse;
6+
import com.alibaba.cloud.ai.graph.agent.interceptor.ToolInterceptor;
7+
import org.slf4j.Logger;
8+
import org.slf4j.LoggerFactory;
9+
10+
11+
public class LogToolInterceptor extends ToolInterceptor {
12+
13+
private static final Logger log = LoggerFactory.getLogger(LogToolInterceptor.class);
14+
15+
@Override
16+
public ToolCallResponse interceptToolCall(ToolCallRequest request, ToolCallHandler handler) {
17+
log.info("ToolInterceptor: Tool {} is called!", request.getToolName());
18+
return handler.call(request);
19+
}
20+
21+
@Override
22+
public String getName() {
23+
return "LogToolInterceptor";
24+
}
25+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package com.cloud.alibaba.ai.example.tools;
2+
3+
import com.fasterxml.jackson.annotation.JsonClassDescription;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
6+
import org.springframework.ai.chat.model.ToolContext;
7+
import org.springframework.ai.tool.ToolCallback;
8+
import org.springframework.ai.tool.function.FunctionToolCallback;
9+
10+
import java.io.IOException;
11+
import java.nio.file.Files;
12+
import java.nio.file.Path;
13+
14+
public class FileReadTool implements Tool<FileReadTool.Request, String> {
15+
@Override
16+
public ToolCallback toolCallback() {
17+
return FunctionToolCallback.builder("file_read", this)
18+
.description("Tool for read files. ")
19+
.inputType(Request.class)
20+
.build();
21+
}
22+
23+
@Override
24+
public String apply(FileReadTool.Request request, ToolContext toolContext) {
25+
try {
26+
return Files.readString(Path.of(request.filePath));
27+
} catch (IOException e) {
28+
return "Error reading file: " + e.getMessage();
29+
}
30+
}
31+
32+
@JsonClassDescription("Request for the FileReadTool")
33+
public record Request(
34+
@JsonProperty(value = "file_path", required = true)
35+
@JsonPropertyDescription("The path of the file to read")
36+
String filePath
37+
) {}
38+
}

0 commit comments

Comments
 (0)