Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 37 additions & 26 deletions dspace_rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from requests import Request
import pysolr
import smart_open
from typing import cast, IO

from .models import (
SimpleDSpaceObject,
Expand Down Expand Up @@ -92,7 +93,6 @@ class DSpaceClient:
"""

# Set up basic environment, variables
session = None
API_ENDPOINT = "http://localhost:8080/server/api"
SOLR_ENDPOINT = "http://localhost:8983/solr"
SOLR_AUTH = None
Expand All @@ -112,7 +112,6 @@ class DSpaceClient:
SOLR_AUTH = os.environ["SOLR_AUTH"]
if "USER_AGENT" in os.environ:
USER_AGENT = os.environ["USER_AGENT"]
verbose = False
ITER_PAGE_SIZE = 20
PROXY_DICT = dict(http=os.environ["PROXY_URL"],https=os.environ["PROXY_URL"]) if "PROXY_URL" in os.environ else dict()

Expand All @@ -123,6 +122,7 @@ class PatchOperation:
REPLACE = "replace"
MOVE = "move"

@staticmethod
def paginated(embed_name, item_constructor, embedding=lambda x: x):
"""
@param embed_name: The key under '_embedded' in the JSON response that contains the
Expand Down Expand Up @@ -153,7 +153,7 @@ def do_paginate(url, params):
else:
url = None

return fun(do_paginate, self, *args, **kwargs)
return fun(self, do_paginate, *args, **kwargs)

return decorated
return decorator
Expand Down Expand Up @@ -284,7 +284,7 @@ def authenticate(self, retry=False):
# Update headers with new bearer token if present
if "Authorization" in r.headers:
self.session.headers.update(
{"Authorization": r.headers.get("Authorization")}
{"Authorization": r.headers["Authorization"]}
)

# Get and check authentication status
Expand All @@ -294,7 +294,7 @@ def authenticate(self, retry=False):
)
if r.status_code == 200:
r_json = parse_json(r)
if "authenticated" in r_json and r_json["authenticated"] is True:
if r_json is not None and "authenticated" in r_json and r_json["authenticated"] is True:
logging.info("Authenticated successfully as %s", self.USERNAME)
return r_json["authenticated"]

Expand Down Expand Up @@ -503,6 +503,8 @@ def search_objects(

r_json = self.fetch_resource(url=url, params={**params, **filters})

if r_json is None:
return dsos
# instead of lots of 'does this key exist, etc etc' checks, just go for it and wrap in a try?
try:
results = r_json["_embedded"]["searchResult"]["_embedded"]["objects"]
Expand All @@ -523,8 +525,8 @@ def search_objects(
embedding=lambda x: x["_embedded"]["searchResult"],
)
def search_objects_iter(
do_paginate,
self,
do_paginate,
query=None,
scope=None,
filters=None,
Expand Down Expand Up @@ -611,8 +613,10 @@ def create_dso(self, url, params, data, embeds=None):
if r.status_code == 201:
# 201 Created - success!
new_dso = parse_json(r)
if new_dso is None:
return r
logging.info(
"%s %s created successfully!", new_dso["type"], new_dso["uuid"]
"%s %s created successfully!", new_dso.get("type"), new_dso.get("uuid")
)
else:
logging.error(
Expand Down Expand Up @@ -702,7 +706,7 @@ def delete_dso(self, dso=None, url=None, params=None):
)
return None
except ValueError as e:
logging.error("Error deleting DSO %s: %s", dso.uuid, e)
logging.error("Error deleting DSO %s: %s", url, e)
return None

# PAGINATION
Expand Down Expand Up @@ -739,7 +743,7 @@ def get_bundles(
try:
if single_result:
bundles.append(Bundle(r_json))
if not single_result:
if not single_result and r_json is not None:
resources = r_json["_embedded"]["bundles"]
for resource in resources:
bundles.append(Bundle(resource))
Expand All @@ -749,7 +753,7 @@ def get_bundles(
return bundles

@paginated("bundles", Bundle)
def get_bundles_iter(do_paginate, self, parent, sort=None, embeds=None):
def get_bundles_iter(self, do_paginate, parent, sort=None, embeds=None):
"""
Get bundles for an item, automatically handling pagination by requesting the next page when all items from one page have been consumed
@param parent: python Item object, from which the UUID will be referenced in the URL.
Expand Down Expand Up @@ -825,7 +829,7 @@ def get_bitstreams(
params["sort"] = sort

r_json = self.fetch_resource(url, params=params)
if "_embedded" in r_json:
if r_json is not None and "_embedded" in r_json:
if "bitstreams" in r_json["_embedded"]:
bitstreams = []
for bitstream_resource in r_json["_embedded"]["bitstreams"]:
Expand All @@ -834,7 +838,7 @@ def get_bitstreams(
return bitstreams

@paginated("bitstreams", Bitstream)
def get_bitstreams_iter(do_paginate, self, bundle, sort=None, embeds=None):
def get_bitstreams_iter(self, do_paginate, bundle, sort=None, embeds=None):
"""
Get all bitstreams for a specific bundle, automatically handling pagination by requesting the next page when all items from one page have been consumed
@param bundle: A python Bundle object to parse for bitstream links to retrieve
Expand Down Expand Up @@ -891,10 +895,14 @@ def create_bitstream(
# TODO: Better error detection and handling for file reading
if metadata is None:
metadata = {}
if bundle is None:
logging.error("Cannot create bitstream without bundle")
return None

url = f"{self.API_ENDPOINT}/core/bundles/{bundle.uuid}/bitstreams"

try:
with smart_open.open(path, "rb") as file_obj:
with cast(IO[bytes], smart_open.open(path, "rb")) as file_obj:
file = (name, file_obj.read(), mime)
files = {"file": file}
properties = {"name": name, "metadata": metadata, "bundleName": bundle.name}
Expand Down Expand Up @@ -923,7 +931,7 @@ def create_bitstream(
# we should enhance self.api_post to be able to send files and use our decorators
if r.status_code == 403:
r_json = parse_json(r)
if "message" in r_json and "CSRF token" in r_json["message"]:
if r_json is not None and "message" in r_json and "CSRF token" in r_json["message"]:
if retry:
logging.error("Already retried... something must be wrong")
else:
Expand Down Expand Up @@ -1000,18 +1008,18 @@ def get_communities(
r_json = self.fetch_resource(url, params)
# Empty list
communities = []
if "_embedded" in r_json:
if r_json is not None and "_embedded" in r_json:
if "communities" in r_json["_embedded"]:
for community_resource in r_json["_embedded"]["communities"]:
communities.append(Community(community_resource))
elif "uuid" in r_json:
elif r_json is not None and "uuid" in r_json:
# This is a single community
communities.append(Community(r_json))
# Return list (populated or empty)
return communities

@paginated("communities", Community)
def get_communities_iter(do_paginate, self, sort=None, top=False, embeds=None):
def get_communities_iter(self, do_paginate, sort=None, top=False, embeds=None):
"""
Get communities as an iterator, automatically handling pagination by requesting the next page when all items from one page have been consumed
@param top: whether to restrict search to top communities (default: false)
Expand Down Expand Up @@ -1089,20 +1097,20 @@ def get_collections(
r_json = self.fetch_resource(url, params=params)
# Empty list
collections = []
if "_embedded" in r_json:
if r_json is not None and "_embedded" in r_json:
# This is a list of collections
if "collections" in r_json["_embedded"]:
for collection_resource in r_json["_embedded"]["collections"]:
collections.append(Collection(collection_resource))
elif "uuid" in r_json:
elif r_json is not None and "uuid" in r_json:
# This is a single collection
collections.append(Collection(r_json))

# Return list (populated or empty)
return collections

@paginated("collections", Collection)
def get_collections_iter(do_paginate, self, community=None, sort=None, embeds=None):
def get_collections_iter(self, do_paginate, community=None, sort=None, embeds=None):
"""
Get collections as an iterator, automatically handling pagination by requesting the next page when all items from one page have been consumed
@param community: Community object. If present, collections for a community
Expand Down Expand Up @@ -1167,12 +1175,12 @@ def get_items(self, embeds=None):
r_json = self.fetch_resource(url, params=parse_params(embeds=embeds))
# Empty list
items = []
if "_embedded" in r_json:
if r_json is not None and "_embedded" in r_json:
# This is a list of items
if "items" in r_json["_embedded"]:
for item_resource in r_json["_embedded"]["items"]:
items.append(Item(item_resource))
elif "uuid" in r_json:
elif r_json is not None and "uuid" in r_json:
# This is a single item
items.append(Item(r_json))

Expand Down Expand Up @@ -1355,14 +1363,14 @@ def get_users(self, page=0, size=20, sort=None, embeds=None):
params["sort"] = sort
r = self.api_get(url, params=params)
r_json = parse_json(response=r)
if "_embedded" in r_json:
if r_json is not None and "_embedded" in r_json:
if "epersons" in r_json["_embedded"]:
for user_resource in r_json["_embedded"]["epersons"]:
users.append(User(user_resource))
return users

@paginated("epersons", User)
def get_users_iter(do_paginate, self, sort=None, embeds=None):
def get_users_iter(self, do_paginate, sort=None, embeds=None):
"""
Get an iterator of users (epersons) in the DSpace instance, automatically handling pagination by requesting the next page when all items from one page have been consumed
@param sort: Optional sort parameter
Expand All @@ -1377,7 +1385,7 @@ def get_users_iter(do_paginate, self, sort=None, embeds=None):
return do_paginate(url, params)

@paginated("groups", Group)
def search_groups_by_metadata_iter(do_paginate, self, query, embeds=None):
def search_groups_by_metadata_iter(self, do_paginate, query, embeds=None):
"""
Search for groups by metadata
@param query: Search query (UUID or group name)
Expand Down Expand Up @@ -1495,7 +1503,7 @@ def resolve_identifier_to_dso(self, identifier=None):
logging.error(f"Error resolving identifier {identifier} to DSO: {r.status_code}")

@paginated("resourcepolicies", ResourcePolicy)
def get_resource_policies_iter(do_paginate, self, parent=None, action=None, embeds=None):
def get_resource_policies_iter(self, do_paginate, parent=None, action=None, embeds=None):
"""
Get resource policies (as an iterator) for a given parent object and action
@param parent: UUID of an object to which the policy applies
Expand Down Expand Up @@ -1544,6 +1552,9 @@ def create_resource_policy(self, resource_policy, parent=None, eperson=None, gro
if r.status_code == 200 or r.status_code == 201:
# 200 OK or 201 Created means Created - success! (201 is used now, 200 perhaps in the past?)
new_policy = parse_json(r)
if new_policy is None:
logging.error("Response containing new resource policy is empty or invalid")
return None
logging.info("%s %s created successfully!",
new_policy["type"], new_policy["id"])
return ResourcePolicy(api_resource=new_policy)
Expand Down
Loading