diff --git a/src/git/src/mcp_server_git/server.py b/src/git/src/mcp_server_git/server.py index 5ce953e545..6981c66c52 100644 --- a/src/git/src/mcp_server_git/server.py +++ b/src/git/src/mcp_server_git/server.py @@ -74,6 +74,26 @@ class GitShow(BaseModel): +class GitPush(BaseModel): + repo_path: str + remote: str = Field(default="origin", description="Name of the remote to push to") + branch: str = Field(default="", description="Branch to push. Defaults to current branch") + force: bool = Field(default=False, description="Force push (use with caution)") + + +class GitFetch(BaseModel): + repo_path: str + remote: str = Field(default="origin", description="Name of the remote to fetch from") + prune: bool = Field(default=False, description="Remove remote-tracking branches that no longer exist on the remote") + + +class GitPull(BaseModel): + repo_path: str + remote: str = Field(default="origin", description="Name of the remote to pull from") + branch: str = Field(default="", description="Branch to pull. Defaults to current branch") + rebase: bool = Field(default=False, description="Rebase instead of merge when pulling") + + class GitBranch(BaseModel): repo_path: str = Field( ..., @@ -107,6 +127,9 @@ class GitTools(str, Enum): SHOW = "git_show" BRANCH = "git_branch" + PUSH = "git_push" + FETCH = "git_fetch" + PULL = "git_pull" def git_status(repo: git.Repo) -> str: return repo.git.status() @@ -255,6 +278,42 @@ def validate_repo_path(repo_path: Path, allowed_repository: Path | None) -> None ) +def git_push(repo: git.Repo, remote: str = "origin", branch: str = "", force: bool = False) -> str: + # Defense in depth: reject names starting with '-' to prevent flag injection + if remote.startswith("-"): + raise BadName(f"Invalid remote: '{remote}' - cannot start with '-'") + if branch and branch.startswith("-"): + raise BadName(f"Invalid branch: '{branch}' - cannot start with '-'") + + args = ["--force"] if force else [] + target_branch = branch if branch else repo.active_branch.name + info = repo.git.push(remote, target_branch, *args) + return f"Pushed '{target_branch}' to '{remote}' successfully.\n{info}" + + +def git_fetch(repo: git.Repo, remote: str = "origin", prune: bool = False) -> str: + # Defense in depth: reject names starting with '-' to prevent flag injection + if remote.startswith("-"): + raise BadName(f"Invalid remote: '{remote}' - cannot start with '-'") + + args = ["--prune"] if prune else [] + info = repo.git.fetch(remote, *args) + return f"Fetched from '{remote}' successfully.\n{info}" + + +def git_pull(repo: git.Repo, remote: str = "origin", branch: str = "", rebase: bool = False) -> str: + # Defense in depth: reject names starting with '-' to prevent flag injection + if remote.startswith("-"): + raise BadName(f"Invalid remote: '{remote}' - cannot start with '-'") + if branch and branch.startswith("-"): + raise BadName(f"Invalid branch: '{branch}' - cannot start with '-'") + + args = ["--rebase"] if rebase else [] + target_branch = branch if branch else repo.active_branch.name + info = repo.git.pull(remote, target_branch, *args) + return f"Pulled '{target_branch}' from '{remote}' successfully.\n{info}" + + def git_branch(repo: git.Repo, branch_type: str, contains: str | None = None, not_contains: str | None = None) -> str: # Defense in depth: reject values starting with '-' to prevent flag injection if contains and contains.startswith("-"): @@ -437,7 +496,40 @@ async def list_tools() -> list[Tool]: idempotentHint=True, openWorldHint=False, ), - ) + ), + Tool( + name=GitTools.PUSH, + description="Push commits to a remote repository", + inputSchema=GitPush.model_json_schema(), + annotations=ToolAnnotations( + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=True, + ), + ), + Tool( + name=GitTools.FETCH, + description="Download objects and refs from a remote repository without merging", + inputSchema=GitFetch.model_json_schema(), + annotations=ToolAnnotations( + readOnlyHint=False, + destructiveHint=False, + idempotentHint=True, + openWorldHint=True, + ), + ), + Tool( + name=GitTools.PULL, + description="Fetch from a remote repository and integrate with the current branch", + inputSchema=GitPull.model_json_schema(), + annotations=ToolAnnotations( + readOnlyHint=False, + destructiveHint=False, + idempotentHint=False, + openWorldHint=True, + ), + ), ] async def list_repos() -> Sequence[str]: @@ -579,6 +671,41 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: text=result )] + case GitTools.PUSH: + result = git_push( + repo, + arguments.get("remote", "origin"), + arguments.get("branch", ""), + arguments.get("force", False), + ) + return [TextContent( + type="text", + text=result + )] + + case GitTools.FETCH: + result = git_fetch( + repo, + arguments.get("remote", "origin"), + arguments.get("prune", False), + ) + return [TextContent( + type="text", + text=result + )] + + case GitTools.PULL: + result = git_pull( + repo, + arguments.get("remote", "origin"), + arguments.get("branch", ""), + arguments.get("rebase", False), + ) + return [TextContent( + type="text", + text=result + )] + case _: raise ValueError(f"Unknown tool: {name}")