Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,58 @@ async def test_update_auto_now(db):
assert obj1.updated_at.date() == updated_at.date()


@pytest.mark.asyncio
async def test_update_auto_now_with_update_fields(db):
tournament = await Tournament.create(name="1")
event = await Event.create(name="original", tournament=tournament)
original_modified = event.modified

# Set modified to the past so we can detect if it gets updated
past = timezone.now() - timedelta(days=1)
await Event.filter(pk=event.pk).update(modified=past)

event = await Event.get(pk=event.pk)
assert event.modified.date() == past.date()

# Update only name with update_fields; auto_now field should also be updated
event.name = "updated"
await event.save(update_fields=["name"])

event = await Event.get(pk=event.pk)
assert event.name == "updated"
assert event.modified.date() == timezone.now().date()


@pytest.mark.asyncio
async def test_bulk_update_auto_now(db):
tournament = await Tournament.create(name="1")
event1 = await Event.create(name="original1", tournament=tournament)
event2 = await Event.create(name="original2", tournament=tournament)

# Set modified to the past so we can detect if it gets updated
past = timezone.now() - timedelta(days=1)
await Event.filter(pk__in=[event1.pk, event2.pk]).update(modified=past)

event1 = await Event.get(pk=event1.pk)
event2 = await Event.get(pk=event2.pk)
assert event1.modified.date() == past.date()
assert event2.modified.date() == past.date()

# bulk_update only name; auto_now field should also be updated
event1.name = "updated1"
event2.name = "updated2"
await Event.filter(pk__in=[event1.pk, event2.pk]).bulk_update(
[event1, event2], fields=["name"]
)

event1 = await Event.get(pk=event1.pk)
event2 = await Event.get(pk=event2.pk)
assert event1.name == "updated1"
assert event2.name == "updated2"
assert event1.modified.date() == timezone.now().date()
assert event2.modified.date() == timezone.now().date()


@pytest.mark.asyncio
async def test_update_relation(db):
tournament_first = await Tournament.create(name="1")
Expand Down
6 changes: 6 additions & 0 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,12 @@ async def save(
raise IncompleteInstanceError(
f"{self.__class__.__name__} is a partial model, can only be saved with the relevant update_field provided"
)
if update_fields:
update_fields = list(update_fields)
for field_name, field_obj in self._meta.fields_map.items():
if field_name not in update_fields and getattr(field_obj, "auto_now", False):
update_fields.append(field_name)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about queryset .update() and .bulk_update() methods?


await self._pre_save(db, update_fields)

if force_create:
Expand Down
6 changes: 5 additions & 1 deletion tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1933,7 +1933,11 @@ def __init__(
limit=limit,
orderings=orderings,
)
self.fields = fields
fields_list = list(fields)
for field_name, field_obj in model._meta.fields_map.items():
if field_name not in fields_list and getattr(field_obj, "auto_now", False):
fields_list.append(field_name)
self.fields = fields_list
self._objects = objects
self._batch_size = batch_size
self._queries: list[QueryBuilder] = []
Expand Down