diff --git a/google_auth_oauthlib/flow.py b/google_auth_oauthlib/flow.py index 9c52711..bc014a6 100644 --- a/google_auth_oauthlib/flow.py +++ b/google_auth_oauthlib/flow.py @@ -410,8 +410,9 @@ def run_local_server( in the user's browser. redirect_uri_trailing_slash (bool): whether or not to add trailing slash when constructing the redirect_uri. Default value is True. - timeout_seconds (int): It will raise an error after the timeout timing - if there are no credentials response. The value is in seconds. + timeout_seconds (int): It will raise a WSGITimeoutError exception after the + timeout timing if there are no credentials response. The value is in + seconds. When set to None there is no timeout. Default value is None. token_audience (str): Passed along with the request for an access @@ -425,6 +426,10 @@ def run_local_server( Returns: google.oauth2.credentials.Credentials: The OAuth 2.0 credentials for the user. + + Raises: + WSGITimeoutError: If there is a timeout when waiting for the response from the + authorization server. """ wsgi_app = _RedirectWSGIApp(success_message) # Fail fast if the address is occupied @@ -455,7 +460,15 @@ def run_local_server( # Note: using https here because oauthlib is very picky that # OAuth 2.0 should only occur over https. - authorization_response = wsgi_app.last_request_uri.replace("http", "https") + try: + authorization_response = wsgi_app.last_request_uri.replace( + "http", "https" + ) + except AttributeError as e: + raise WSGITimeoutError( + "Timed out waiting for response from authorization server" + ) from e + self.fetch_token( authorization_response=authorization_response, audience=token_audience ) @@ -506,3 +519,7 @@ def __call__(self, environ, start_response): start_response("200 OK", [("Content-type", "text/plain; charset=utf-8")]) self.last_request_uri = wsgiref.util.request_uri(environ) return [self._success_message.encode("utf-8")] + + +class WSGITimeoutError(AttributeError): + """Raised when the WSGI server times out waiting for a response.""" diff --git a/tests/unit/test_flow.py b/tests/unit/test_flow.py index 8f5b561..d8be492 100644 --- a/tests/unit/test_flow.py +++ b/tests/unit/test_flow.py @@ -497,3 +497,20 @@ def test_run_local_server_logs_and_prints_url( urllib.parse.quote(instance.redirect_uri, safe="") in print_mock.call_args[0][0] ) + + @mock.patch("google_auth_oauthlib.flow.webbrowser", autospec=True) + @mock.patch("wsgiref.simple_server.make_server", autospec=True) + def test_run_local_server_timeout( + self, make_server_mock, webbrowser_mock, instance, mock_fetch_token + ): + mock_server = mock.Mock() + make_server_mock.return_value = mock_server + + # handle_request does nothing (simulating timeout), so last_request_uri remains None + mock_server.handle_request.return_value = None + + with pytest.raises(flow.WSGITimeoutError): + instance.run_local_server(timeout_seconds=1) + + webbrowser_mock.get.assert_called_with(None) + webbrowser_mock.get.return_value.open.assert_called_once()