Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions src/dstack/_internal/core/services/diff.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import Any, Optional, TypedDict, TypeVar
from typing import Any, Optional, TypeVar, Union

from pydantic import BaseModel

from dstack._internal.core.models.common import IncludeExcludeType
from dstack._internal.core.models.common import CoreModel, IncludeExcludeType


class ModelFieldDiff(TypedDict):
class ModelFieldDiff(CoreModel):
old: Any
new: Any


ModelDiff = dict[str, ModelFieldDiff]
ModelDiff = dict[str, Union[ModelFieldDiff, "ModelDiff"]]


# TODO: calculate nested diffs
Expand Down Expand Up @@ -45,7 +45,7 @@ def diff_models(
old_value = getattr(old, field)
new_value = getattr(new, field)
if old_value != new_value:
changes[field] = {"old": old_value, "new": new_value}
changes[field] = ModelFieldDiff(old=old_value, new=new_value)

return changes

Expand All @@ -69,3 +69,26 @@ def copy_model(model: M, reset: Optional[IncludeExcludeType] = None) -> M:
A deep copy of the model instance.
"""
return type(model).parse_obj(model.dict(exclude=reset))


def flatten_diff_fields(diff: ModelDiff, prefix: str = "") -> list[str]:
"""
Recursively collects all field paths from a diff.

Returns:
A list of field paths, each path with dot-separated parts.
"""
fields = []
for field_name, field_diff in diff.items():
current_path = f"{prefix}.{field_name}" if prefix else field_name

if isinstance(field_diff, ModelFieldDiff):
fields.append(current_path)
else:
fields.extend(flatten_diff_fields(field_diff, current_path))

return fields


def format_diff_fields_for_event(diff: ModelDiff) -> str:
return ", ".join(flatten_diff_fields(diff))
34 changes: 27 additions & 7 deletions src/dstack/_internal/server/services/runs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
RunTerminationReason,
ServiceSpec,
)
from dstack._internal.core.services.diff import format_diff_fields_for_event
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
from dstack._internal.server.models import (
FleetModel,
Expand Down Expand Up @@ -258,11 +259,11 @@ async def get_run(
raise ServerClientError("run_name or id must be specified")


async def get_run_by_name(
async def get_run_model_by_name(
session: AsyncSession,
project: ProjectModel,
run_name: str,
) -> Optional[Run]:
) -> Optional[RunModel]:
res = await session.execute(
select(RunModel)
.where(
Expand All @@ -274,7 +275,15 @@ async def get_run_by_name(
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
.options(selectinload(RunModel.jobs).joinedload(JobModel.probes))
)
run_model = res.scalar()
return res.scalar()


async def get_run_by_name(
session: AsyncSession,
project: ProjectModel,
run_name: str,
) -> Optional[Run]:
run_model = await get_run_model_by_name(session=session, project=project, run_name=run_name)
if run_model is None:
return None
return run_model_to_run(run_model, return_in_api=True)
Expand Down Expand Up @@ -386,24 +395,25 @@ async def apply_plan(
project=project,
run_spec=run_spec,
)
current_resource = await get_run_by_name(
current_resource_model = await get_run_model_by_name(
session=session,
project=project,
run_name=run_spec.run_name,
)
if current_resource is None or current_resource.status.is_finished():
if current_resource_model is None or current_resource_model.status.is_finished():
return await submit_run(
session=session,
user=user,
project=project,
run_spec=run_spec,
)
current_resource = run_model_to_run(current_resource_model, return_in_api=True)

# For backward compatibility (current_resource may has been submitted before
# some fields, e.g., CPUSpec.arch, were added)
set_resources_defaults(current_resource.run_spec.configuration.resources)
try:
check_can_update_run_spec(current_resource.run_spec, run_spec)
spec_diff = check_can_update_run_spec(current_resource.run_spec, run_spec)
except ServerClientError:
# The except is only needed to raise an appropriate error if run is active
if not current_resource.status.is_finished():
Expand All @@ -420,6 +430,7 @@ async def apply_plan(
raise ServerClientError(
"Failed to apply plan. Resource has been changed. Try again or use force apply."
)
new_deployment_num = current_resource.deployment_num + 1
# FIXME: potentially long write transaction
# Avoid getting run_model after update
await session.execute(
Expand All @@ -428,9 +439,18 @@ async def apply_plan(
.values(
run_spec=run_spec.json(),
priority=run_spec.configuration.priority,
deployment_num=current_resource.deployment_num + 1,
deployment_num=new_deployment_num,
)
)
events.emit(
session,
(
f"Run updated. Deployment: {new_deployment_num}."
f" Changed fields: {format_diff_fields_for_event(spec_diff)}"
),
actor=events.UserActor.from_user(user),
targets=[events.Target.from_model(current_resource_model)],
)
run = await get_run_by_name(
session=session,
project=project,
Expand Down
22 changes: 18 additions & 4 deletions src/dstack/_internal/server/services/runs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dstack._internal.core.models.runs import LEGACY_REPO_DIR, AnyRunConfiguration, RunSpec
from dstack._internal.core.models.volumes import InstanceMountPoint
from dstack._internal.core.services import validate_dstack_resource_name
from dstack._internal.core.services.diff import diff_models
from dstack._internal.core.services.diff import ModelDiff, diff_models
from dstack._internal.server import settings
from dstack._internal.server.models import UserModel
from dstack._internal.server.services.docker import is_valid_docker_volume_target
Expand Down Expand Up @@ -117,7 +117,13 @@ def validate_run_spec_and_set_defaults(
run_spec.configuration.working_dir = LEGACY_REPO_DIR


def check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec):
def check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec) -> ModelDiff:
"""
Check if in-place update is possible.

Returns the diff if it is possible.
Raises ServerClientError otherwise.
"""
spec_diff = diff_models(current_run_spec, new_run_spec)
changed_spec_fields = list(spec_diff.keys())
updatable_spec_fields = _UPDATABLE_SPEC_FIELDS + _TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS.get(
Expand All @@ -133,9 +139,10 @@ def check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec):
# are the same (the same id => hash => content and the same container path), the order of
# unpacking matters when one path is a subpath of another.
ignore_files = current_run_spec.file_archives == new_run_spec.file_archives
_check_can_update_configuration(
spec_diff["configuration"] = _check_can_update_configuration(
current_run_spec.configuration, new_run_spec.configuration, ignore_files
)
return spec_diff


def can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec) -> bool:
Expand Down Expand Up @@ -167,7 +174,13 @@ def check_run_spec_requires_instance_mounts(run_spec: RunSpec) -> bool:

def _check_can_update_configuration(
current: AnyRunConfiguration, new: AnyRunConfiguration, ignore_files: bool
) -> None:
) -> ModelDiff:
"""
Check if in-place update is possible.

Returns the diff if it is possible.
Raises ServerClientError otherwise.
"""
if current.type != new.type:
raise ServerClientError(
f"Configuration type changed from {current.type} to {new.type}, cannot update"
Expand All @@ -189,3 +202,4 @@ def _check_can_update_configuration(
raise ServerClientError(
f"Failed to update fields {changed_fields}. Can only update {updatable_fields}"
)
return diff
59 changes: 59 additions & 0 deletions src/tests/_internal/core/services/test_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest

from dstack._internal.core.services.diff import ModelDiff, ModelFieldDiff, flatten_diff_fields


@pytest.mark.parametrize(
"diff,expected",
[
pytest.param({}, [], id="empty_diff"),
pytest.param(
{
"field1": ModelFieldDiff(old="old1", new="new1"),
"field2": ModelFieldDiff(old="old2", new="new2"),
},
["field1", "field2"],
id="multiple_fields",
),
pytest.param(
{
"field1": ModelFieldDiff(old="old1", new="new1"),
"nested": {
"sub1": ModelFieldDiff(old="old_sub1", new="new_sub1"),
},
},
["field1", "nested.sub1"],
id="nested_single_level",
),
pytest.param(
{
"field1": ModelFieldDiff(old="old1", new="new1"),
"level1": {
"level2": {
"level3": {"deep_field": ModelFieldDiff(old="deep_old", new="deep_new")},
"field2": ModelFieldDiff(old="old2", new="new2"),
},
"field3": ModelFieldDiff(old="old3", new="new3"),
},
},
["field1", "level1.level2.level3.deep_field", "level1.level2.field2", "level1.field3"],
id="nested_multiple_levels",
),
pytest.param(
{
"field1": ModelFieldDiff(old="old1", new="new1"),
"empty_nested": {},
"nested_with_empty": {
"empty_sub": {},
"field2": ModelFieldDiff(old="old2", new="new2"),
},
},
["field1", "nested_with_empty.field2"],
id="empty_nested",
),
],
)
def test_flatten_diff_fields(diff: ModelDiff, expected: list[str]):
"""Test flatten_diff_fields with various diff structures."""
result = flatten_diff_fields(diff)
assert result == expected
12 changes: 12 additions & 0 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from dstack._internal.server.schemas.runs import ApplyRunPlanRequest
from dstack._internal.server.services.projects import add_project_member
from dstack._internal.server.services.runs import run_model_to_run
from dstack._internal.server.services.runs.spec import validate_run_spec_and_set_defaults
from dstack._internal.server.testing.common import (
create_backend,
create_fleet,
Expand Down Expand Up @@ -1549,6 +1550,7 @@ async def test_updates_run(self, test_db, session: AsyncSession, client: AsyncCl
repo = await create_repo(session=session, project_id=project.id)
run_spec = get_run_spec(
run_name="test-service",
configuration_path="old.dstack.yml",
repo_id=repo.name,
configuration=ServiceConfiguration(
type="service",
Expand All @@ -1557,6 +1559,8 @@ async def test_updates_run(self, test_db, session: AsyncSession, client: AsyncCl
replicas=Range(min=1, max=1),
),
)
# set defaults to avoid phantom changes being detected
validate_run_spec_and_set_defaults(user, run_spec)
run_model = await create_run(
session=session,
project=project,
Expand All @@ -1566,6 +1570,7 @@ async def test_updates_run(self, test_db, session: AsyncSession, client: AsyncCl
run_spec=run_spec,
)
run = run_model_to_run(run_model)
run_spec.configuration_path = "new.dstack.yml"
run_spec.configuration.replicas = Range(min=2, max=2)
response = await client.post(
f"/api/project/{project.name}/runs/apply",
Expand All @@ -1586,8 +1591,15 @@ async def test_updates_run(self, test_db, session: AsyncSession, client: AsyncCl
updated_run = run_model_to_run(run_model)
assert run.deployment_num == 0
assert updated_run.deployment_num == 1
assert run.run_spec.configuration_path == "old.dstack.yml"
assert updated_run.run_spec.configuration_path == "new.dstack.yml"
assert run.run_spec.configuration.replicas == Range(min=1, max=1)
assert updated_run.run_spec.configuration.replicas == Range(min=2, max=2)
events = await list_events(session)
assert len(events) == 1
assert events[0].message == (
"Run updated. Deployment: 1. Changed fields: configuration_path, configuration.replicas"
)

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand Down