def flatten_diff_fields(diff: ModelDiff, prefix: str = "") -> list[str]:
    """
    Recursively collects all field paths from a diff.

    Args:
        diff: Maps field names either to a `ModelFieldDiff` (a changed field)
            or to a nested diff dict (changes inside a sub-model).
        prefix: Dot-separated path of the parent field. Empty at the top level.

    Returns:
        A list of field paths, each path with dot-separated parts, in the
        iteration order of the diff. Empty nested diffs contribute no paths.

    Raises:
        TypeError: If a value is neither a nested dict nor a `ModelFieldDiff`.
    """
    fields: list[str] = []
    for field_name, field_diff in diff.items():
        current_path = f"{prefix}.{field_name}" if prefix else field_name

        if isinstance(field_diff, dict):
            # Nested diff: descend, carrying the path accumulated so far.
            fields.extend(flatten_diff_fields(field_diff, current_path))
        elif isinstance(field_diff, ModelFieldDiff):
            fields.append(current_path)
        else:
            # Fail here with the offending path instead of an opaque
            # AttributeError on `.items()` one recursion level deeper.
            raise TypeError(
                f"Unexpected diff value for field {current_path!r}:"
                f" {type(field_diff).__name__}"
            )

    return fields


def format_diff_fields_for_event(diff: ModelDiff) -> str:
    """
    Formats the changed field paths of a diff for use in an event message.

    Returns:
        The dot-separated field paths joined with ", ".
        An empty string if the diff contains no changed fields.
    """
    return ", ".join(flatten_diff_fields(diff))
get_run_model_by_name( session=session, project=project, run_name=run_spec.run_name, ) - if current_resource is None or current_resource.status.is_finished(): + if current_resource_model is None or current_resource_model.status.is_finished(): return await submit_run( session=session, user=user, project=project, run_spec=run_spec, ) + current_resource = run_model_to_run(current_resource_model, return_in_api=True) # For backward compatibility (current_resource may has been submitted before # some fields, e.g., CPUSpec.arch, were added) set_resources_defaults(current_resource.run_spec.configuration.resources) try: - check_can_update_run_spec(current_resource.run_spec, run_spec) + spec_diff = check_can_update_run_spec(current_resource.run_spec, run_spec) except ServerClientError: # The except is only needed to raise an appropriate error if run is active if not current_resource.status.is_finished(): @@ -420,6 +430,7 @@ async def apply_plan( raise ServerClientError( "Failed to apply plan. Resource has been changed. Try again or use force apply." ) + new_deployment_num = current_resource.deployment_num + 1 # FIXME: potentially long write transaction # Avoid getting run_model after update await session.execute( @@ -428,9 +439,18 @@ async def apply_plan( .values( run_spec=run_spec.json(), priority=run_spec.configuration.priority, - deployment_num=current_resource.deployment_num + 1, + deployment_num=new_deployment_num, ) ) + events.emit( + session, + ( + f"Run updated. Deployment: {new_deployment_num}." 
+ f" Changed fields: {format_diff_fields_for_event(spec_diff)}" + ), + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(current_resource_model)], + ) run = await get_run_by_name( session=session, project=project, diff --git a/src/dstack/_internal/server/services/runs/spec.py b/src/dstack/_internal/server/services/runs/spec.py index ad2fcef1ff..f478d187bb 100644 --- a/src/dstack/_internal/server/services/runs/spec.py +++ b/src/dstack/_internal/server/services/runs/spec.py @@ -4,7 +4,7 @@ from dstack._internal.core.models.runs import LEGACY_REPO_DIR, AnyRunConfiguration, RunSpec from dstack._internal.core.models.volumes import InstanceMountPoint from dstack._internal.core.services import validate_dstack_resource_name -from dstack._internal.core.services.diff import diff_models +from dstack._internal.core.services.diff import ModelDiff, diff_models from dstack._internal.server import settings from dstack._internal.server.models import UserModel from dstack._internal.server.services.docker import is_valid_docker_volume_target @@ -117,7 +117,13 @@ def validate_run_spec_and_set_defaults( run_spec.configuration.working_dir = LEGACY_REPO_DIR -def check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec): +def check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec) -> ModelDiff: + """ + Check if in-place update is possible. + + Returns the diff if it is possible. + Raises ServerClientError otherwise. + """ spec_diff = diff_models(current_run_spec, new_run_spec) changed_spec_fields = list(spec_diff.keys()) updatable_spec_fields = _UPDATABLE_SPEC_FIELDS + _TYPE_SPECIFIC_UPDATABLE_SPEC_FIELDS.get( @@ -133,9 +139,10 @@ def check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec): # are the same (the same id => hash => content and the same container path), the order of # unpacking matters when one path is a subpath of another. 
# Each case pairs a diff with the field paths expected after flattening.
_FLATTEN_DIFF_CASES = [
    pytest.param({}, [], id="empty_diff"),
    pytest.param(
        {
            "field1": ModelFieldDiff(old="old1", new="new1"),
            "field2": ModelFieldDiff(old="old2", new="new2"),
        },
        ["field1", "field2"],
        id="multiple_fields",
    ),
    pytest.param(
        {
            "field1": ModelFieldDiff(old="old1", new="new1"),
            "nested": {"sub1": ModelFieldDiff(old="old_sub1", new="new_sub1")},
        },
        ["field1", "nested.sub1"],
        id="nested_single_level",
    ),
    pytest.param(
        {
            "field1": ModelFieldDiff(old="old1", new="new1"),
            "level1": {
                "level2": {
                    "level3": {
                        "deep_field": ModelFieldDiff(old="deep_old", new="deep_new"),
                    },
                    "field2": ModelFieldDiff(old="old2", new="new2"),
                },
                "field3": ModelFieldDiff(old="old3", new="new3"),
            },
        },
        [
            "field1",
            "level1.level2.level3.deep_field",
            "level1.level2.field2",
            "level1.field3",
        ],
        id="nested_multiple_levels",
    ),
    pytest.param(
        {
            "field1": ModelFieldDiff(old="old1", new="new1"),
            "empty_nested": {},
            "nested_with_empty": {
                "empty_sub": {},
                "field2": ModelFieldDiff(old="old2", new="new2"),
            },
        },
        ["field1", "nested_with_empty.field2"],
        id="empty_nested",
    ),
]


@pytest.mark.parametrize(("diff", "expected"), _FLATTEN_DIFF_CASES)
def test_flatten_diff_fields(diff: ModelDiff, expected: list[str]):
    """Flattening a diff yields the dot-separated path of every changed field, in order."""
    assert flatten_diff_fields(diff) == expected
validate_run_spec_and_set_defaults(user, run_spec) run_model = await create_run( session=session, project=project, @@ -1566,6 +1570,7 @@ async def test_updates_run(self, test_db, session: AsyncSession, client: AsyncCl run_spec=run_spec, ) run = run_model_to_run(run_model) + run_spec.configuration_path = "new.dstack.yml" run_spec.configuration.replicas = Range(min=2, max=2) response = await client.post( f"/api/project/{project.name}/runs/apply", @@ -1586,8 +1591,15 @@ async def test_updates_run(self, test_db, session: AsyncSession, client: AsyncCl updated_run = run_model_to_run(run_model) assert run.deployment_num == 0 assert updated_run.deployment_num == 1 + assert run.run_spec.configuration_path == "old.dstack.yml" + assert updated_run.run_spec.configuration_path == "new.dstack.yml" assert run.run_spec.configuration.replicas == Range(min=1, max=1) assert updated_run.run_spec.configuration.replicas == Range(min=2, max=2) + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == ( + "Run updated. Deployment: 1. Changed fields: configuration_path, configuration.replicas" + ) @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)