diff --git a/packages/markitdown/src/markitdown/__main__.py b/packages/markitdown/src/markitdown/__main__.py index ccb44b64b..6b1f67c40 100644 --- a/packages/markitdown/src/markitdown/__main__.py +++ b/packages/markitdown/src/markitdown/__main__.py @@ -138,6 +138,22 @@ def main(): help="Keep data URIs (like base64-encoded images) in the output. By default, data URIs are truncated.", ) + parser.add_argument( + "--llm-client", + choices=["openai"], + help="LLM client to use for converters that support vision or OCR. Currently supports 'openai'.", + ) + + parser.add_argument( + "--llm-model", + help="LLM model to pass to converters that use --llm-client.", + ) + + parser.add_argument( + "--llm-prompt", + help="Optional prompt override for converters that use --llm-client.", + ) + parser.add_argument("filename", nargs="?") args = parser.parse_args() @@ -200,6 +216,8 @@ def main(): ) sys.exit(0) + llm_kwargs = _parse_llm_options(args) + if args.use_docintel: if args.endpoint is None: _exit_with_error( @@ -209,7 +227,7 @@ def main(): _exit_with_error("Filename is required when using Document Intelligence.") markitdown = MarkItDown( - enable_plugins=args.use_plugins, docintel_endpoint=args.endpoint + enable_plugins=args.use_plugins, docintel_endpoint=args.endpoint, **llm_kwargs ) elif args.use_cu: if args.cu_endpoint is None: @@ -240,9 +258,9 @@ def main(): _exit_with_error(f"Unknown file type: {name}") cu_kwargs["cu_file_types"] = cu_types - markitdown = MarkItDown(enable_plugins=args.use_plugins, **cu_kwargs) + markitdown = MarkItDown(enable_plugins=args.use_plugins, **llm_kwargs, **cu_kwargs) else: - markitdown = MarkItDown(enable_plugins=args.use_plugins) + markitdown = MarkItDown(enable_plugins=args.use_plugins, **llm_kwargs) if args.filename is None: result = markitdown.convert_stream( @@ -258,6 +276,40 @@ def main(): _handle_output(args, result) +def _parse_llm_options(args) -> dict[str, Any]: + if args.llm_client is None: + if args.llm_model or args.llm_prompt: + _exit_with_error("--llm-model and --llm-prompt require --llm-client.") + return {} + + if not args.llm_model: + _exit_with_error("--llm-client requires --llm-model.") + + llm_kwargs: dict[str, Any] = { + "llm_client": _create_llm_client(args.llm_client), + "llm_model": args.llm_model, + } + if args.llm_prompt: + llm_kwargs["llm_prompt"] = args.llm_prompt + return llm_kwargs + + +def _create_llm_client(client_name: str) -> Any: + if client_name == "openai": + try: + from openai import OpenAI + except ImportError as ex: + _exit_with_error( + "The OpenAI client is required for --llm-client openai. Install it with `pip install openai`." + ) + raise AssertionError("unreachable") from ex + + return OpenAI() + + _exit_with_error(f"Unsupported LLM client: {client_name}") + raise AssertionError("unreachable") + + def _handle_output(args, result: DocumentConverterResult): """Handle output to stdout or file""" if args.output: diff --git a/packages/markitdown/tests/test_cli_misc.py b/packages/markitdown/tests/test_cli_misc.py index cf6c9ccc7..25fc4fa83 100644 --- a/packages/markitdown/tests/test_cli_misc.py +++ b/packages/markitdown/tests/test_cli_misc.py @@ -1,5 +1,10 @@ #!/usr/bin/env python3 -m pytest import subprocess +import sys +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest from markitdown import __version__ # This file contains CLI tests that are not directly tested by the FileTestVectors. @@ -8,7 +13,7 @@ def test_version() -> None: result = subprocess.run( - ["python", "-m", "markitdown", "--version"], capture_output=True, text=True + [sys.executable, "-m", "markitdown", "--version"], capture_output=True, text=True ) assert result.returncode == 0, f"CLI exited with error: {result.stderr}" @@ -17,7 +22,7 @@ def test_version() -> None: def test_invalid_flag() -> None: result = subprocess.run( - ["python", "-m", "markitdown", "--foobar"], capture_output=True, text=True + [sys.executable, "-m", "markitdown", "--foobar"], capture_output=True, text=True ) assert result.returncode != 0, f"CLI exited with error: {result.stderr}" @@ -27,6 +32,58 @@ def test_invalid_flag() -> None: assert "SYNTAX" in result.stderr, "Expected 'SYNTAX' to appear in STDERR" +def test_llm_cli_options_are_passed_to_markitdown(monkeypatch, capsys) -> None: + import markitdown.__main__ as markitdown_cli + + llm_client = object() + markitdown_instance = Mock() + markitdown_instance.convert.return_value = SimpleNamespace(markdown="converted") + monkeypatch.setattr( + sys, + "argv", + [ + "markitdown", + "document.pdf", + "--use-plugins", + "--llm-client", + "openai", + "--llm-model", + "gpt-4o", + "--llm-prompt", + "Extract the text.", + ], + ) + + with ( + patch.object(markitdown_cli, "_create_llm_client", return_value=llm_client) as create_llm_client, + patch.object(markitdown_cli, "MarkItDown", return_value=markitdown_instance) as markitdown_cls, + ): + markitdown_cli.main() + + create_llm_client.assert_called_once_with("openai") + markitdown_cls.assert_called_once_with( + enable_plugins=True, + llm_client=llm_client, + llm_model="gpt-4o", + llm_prompt="Extract the text.", + ) + markitdown_instance.convert.assert_called_once_with( + "document.pdf", stream_info=None, keep_data_uris=False + ) + assert capsys.readouterr().out.strip() == "converted" + + +def test_llm_model_requires_llm_client(monkeypatch, capsys) -> None: + import markitdown.__main__ as markitdown_cli + + monkeypatch.setattr(sys, "argv", ["markitdown", "document.pdf", "--llm-model", "gpt-4o"]) + + with pytest.raises(SystemExit): + markitdown_cli.main() + + assert "--llm-model and --llm-prompt require --llm-client" in capsys.readouterr().out + + if __name__ == "__main__": """Runs this file's tests from the command line.""" test_version()