Skip to content
42 changes: 42 additions & 0 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import builtins
import ipaddress
import uuid
import warnings
from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
Expand Down Expand Up @@ -246,6 +247,7 @@ def Field(
exclude: Set[int | str] | Mapping[int | str, Any] | Any = None,
include: Set[int | str] | Mapping[int | str, Any] | Any = None,
const: bool | None = None,
coerce_numbers_to_str: bool | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
Expand All @@ -258,9 +260,12 @@ def Field(
unique_items: bool | None = None,
min_length: int | None = None,
max_length: int | None = None,
union_mode: Literal["smart", "left_to_right"] | None = None,
fail_fast: bool | None = None,
allow_mutation: bool = True,
regex: str | None = None,
discriminator: str | None = None,
validate_default: bool | None = None,
repr: bool = True,
primary_key: bool | UndefinedType = Undefined,
foreign_key: Any = Undefined,
Expand Down Expand Up @@ -289,6 +294,7 @@ def Field(
exclude: Set[int | str] | Mapping[int | str, Any] | Any = None,
include: Set[int | str] | Mapping[int | str, Any] | Any = None,
const: bool | None = None,
coerce_numbers_to_str: bool | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
Expand All @@ -301,9 +307,12 @@ def Field(
unique_items: bool | None = None,
min_length: int | None = None,
max_length: int | None = None,
union_mode: Literal["smart", "left_to_right"] | None = None,
fail_fast: bool | None = None,
allow_mutation: bool = True,
regex: str | None = None,
discriminator: str | None = None,
validate_default: bool | None = None,
repr: bool = True,
primary_key: bool | UndefinedType = Undefined,
foreign_key: str,
Expand Down Expand Up @@ -341,6 +350,7 @@ def Field(
exclude: Set[int | str] | Mapping[int | str, Any] | Any = None,
include: Set[int | str] | Mapping[int | str, Any] | Any = None,
const: bool | None = None,
coerce_numbers_to_str: bool | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
Expand All @@ -353,9 +363,12 @@ def Field(
unique_items: bool | None = None,
min_length: int | None = None,
max_length: int | None = None,
union_mode: Literal["smart", "left_to_right"] | None = None,
fail_fast: bool | None = None,
allow_mutation: bool = True,
regex: str | None = None,
discriminator: str | None = None,
validate_default: bool | None = None,
repr: bool = True,
sa_column: Column[Any] | UndefinedType = Undefined,
schema_extra: dict[str, Any] | None = None,
Expand All @@ -374,6 +387,7 @@ def Field(
exclude: Set[int | str] | Mapping[int | str, Any] | Any = None,
include: Set[int | str] | Mapping[int | str, Any] | Any = None,
const: bool | None = None,
coerce_numbers_to_str: bool | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
Expand All @@ -386,9 +400,12 @@ def Field(
unique_items: bool | None = None,
min_length: int | None = None,
max_length: int | None = None,
union_mode: Literal["smart", "left_to_right"] | None = None,
fail_fast: bool | None = None,
allow_mutation: bool = True,
regex: str | None = None,
discriminator: str | None = None,
validate_default: bool | None = None,
repr: bool = True,
primary_key: bool | UndefinedType = Undefined,
foreign_key: Any = Undefined,
Expand All @@ -403,16 +420,36 @@ def Field(
schema_extra: dict[str, Any] | None = None,
) -> Any:
current_schema_extra = schema_extra or {}

for param_name in (
"coerce_numbers_to_str",
"validate_default",
"union_mode",
"fail_fast",
):
if param_name in current_schema_extra:
msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`"
warnings.warn(msg, UserWarning, stacklevel=2)

# Extract possible alias settings from schema_extra so we can control precedence
schema_validation_alias = current_schema_extra.pop("validation_alias", None)
schema_serialization_alias = current_schema_extra.pop("serialization_alias", None)
current_coerce_numbers_to_str = coerce_numbers_to_str or current_schema_extra.pop(
"coerce_numbers_to_str", None
)
current_validate_default = validate_default or current_schema_extra.pop(
"validate_default", None
)
current_fail_fast = fail_fast or current_schema_extra.pop("fail_fast", None)
field_info_kwargs = {
"alias": alias,
"title": title,
"description": description,
"exclude": exclude,
"include": include,
"const": const,
"coerce_numbers_to_str": current_coerce_numbers_to_str,
"validate_default": current_validate_default,
"gt": gt,
"ge": ge,
"lt": lt,
Expand All @@ -425,6 +462,7 @@ def Field(
"unique_items": unique_items,
"min_length": min_length,
"max_length": max_length,
"fail_fast": current_fail_fast,
"allow_mutation": allow_mutation,
"regex": regex,
"discriminator": discriminator,
Expand All @@ -450,6 +488,10 @@ def Field(
serialization_alias or schema_serialization_alias or alias
)

current_union_mode = union_mode or current_schema_extra.pop("union_mode", None)
if current_union_mode is not None:
field_info_kwargs["union_mode"] = current_union_mode

field_info = FieldInfo(
default,
default_factory=default_factory,
Expand Down
191 changes: 190 additions & 1 deletion tests/test_pydantic/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
from pydantic import ValidationError
from sqlmodel import Field, SQLModel
from sqlmodel import Field, Session, SQLModel, create_engine


def test_decimal():
Expand Down Expand Up @@ -54,3 +54,192 @@ class Model(SQLModel):

instance = Model(id=123, foo="bar")
assert "foo=" not in repr(instance)


def test_coerce_numbers_to_str_true():
class Model(SQLModel):
val: str = Field(coerce_numbers_to_str=True)

assert Model.model_validate({"val": 123}).val == "123"
assert Model.model_validate({"val": 45.67}).val == "45.67"


@pytest.mark.parametrize("coerce_numbers_to_str", [None, False])
def test_coerce_numbers_to_str_false(coerce_numbers_to_str: bool | None):
class Model2(SQLModel):
val: str = Field(coerce_numbers_to_str=coerce_numbers_to_str)

with pytest.raises(ValidationError):
Model2.model_validate({"val": 123})


def test_coerce_numbers_to_str_via_schema_extra(): # Current workaround. Remove after some time
with pytest.warns(
UserWarning,
match=(
"Pass `coerce_numbers_to_str` parameter directly to Field instead of passing "
"it via `schema_extra`"
),
):

class Model(SQLModel):
val: str = Field(schema_extra={"coerce_numbers_to_str": True})

assert Model.model_validate({"val": 123}).val == "123"
assert Model.model_validate({"val": 45.67}).val == "45.67"


def test_validate_default_true():
class Model(SQLModel):
val: int = Field(default="123", validate_default=True)

assert Model.model_validate({}).val == 123

class Model2(SQLModel):
val: int = Field(default=None, validate_default=True)

with pytest.raises(ValidationError):
Model2.model_validate({})


def test_validate_default_table_model():
class Model(SQLModel):
id: int | None = Field(default=None, primary_key=True)
val: int = Field(default="123", validate_default=True)

class ModelDB(Model, table=True):
pass

engine = create_engine("sqlite://", echo=True)

SQLModel.metadata.create_all(engine)

model = ModelDB()
with Session(engine) as session:
session.add(model)
session.commit()
session.refresh(model)

assert model.val == 123


@pytest.mark.parametrize("validate_default", [None, False])
def test_validate_default_false(validate_default: bool | None):
class Model3(SQLModel):
val: int = Field(default="123", validate_default=validate_default)

assert Model3().val == "123"


def test_validate_default_via_schema_extra(): # Current workaround. Remove after some time
with pytest.warns(
UserWarning,
match=(
"Pass `validate_default` parameter directly to Field instead of passing "
"it via `schema_extra`"
),
):

class Model(SQLModel):
val: int = Field(default="123", schema_extra={"validate_default": True})

assert Model.model_validate({}).val == 123


@pytest.mark.parametrize("union_mode", [None, "smart"])
def test_union_mode_smart(union_mode: Literal["smart"] | None):
class Model(SQLModel):
val: float | int = Field(union_mode=union_mode)

a = Model.model_validate({"val": 123})
assert isinstance(a.val, int) # float is first, but int is more precise

b = Model.model_validate({"val": 123.0})
assert isinstance(b.val, float)

c = Model.model_validate({"val": 123.1})
assert isinstance(c.val, float)


def test_union_mode_left_to_right():
class Model(SQLModel):
val: float | int = Field(union_mode="left_to_right")

a = Model.model_validate({"val": 123})
assert isinstance(a.val, float)

b = Model.model_validate({"val": 123.0})
assert isinstance(b.val, float)

c = Model.model_validate({"val": 123.1})
assert isinstance(c.val, float)


def test_union_mode_via_schema_extra(): # Current workaround. Remove after some time
with pytest.warns(
UserWarning,
match=(
"Pass `union_mode` parameter directly to Field instead of passing "
"it via `schema_extra`"
),
):

class Model(SQLModel):
val: float | int = Field(schema_extra={"union_mode": "smart"})

a = Model.model_validate({"val": 123})
assert isinstance(a.val, int) # float is first, but int is more precise

b = Model.model_validate({"val": 123.0})
assert isinstance(b.val, float)

c = Model.model_validate({"val": 123.1})
assert isinstance(c.val, float)


def test_fail_fast_true():
class Model(SQLModel):
val: list[int] = Field(fail_fast=True)

with pytest.raises(ValidationError) as exc_info:
Model.model_validate({"val": [1.1, "not an int"]})

errors = exc_info.value.errors()
assert len(errors) == 1
assert errors[0]["type"] == "int_from_float"


@pytest.mark.parametrize("fail_fast", [None, False])
def test_fail_fast_false(fail_fast: bool | None):
class Model(SQLModel):
val: list[int] = Field(fail_fast=fail_fast)

with pytest.raises(ValidationError) as exc_info:
Model.model_validate({"val": [1.1, "not an int"]})

errors = exc_info.value.errors()
assert len(errors) == 2
error_types = {error["type"] for error in errors}

assert "int_from_float" in error_types
assert "int_parsing" in error_types


def test_fail_fast_via_schema_extra(): # Current workaround. Remove after some time
with pytest.warns(
UserWarning,
match=(
"Pass `fail_fast` parameter directly to Field instead of passing "
"it via `schema_extra`"
),
):

class Model(SQLModel):
val: list[int] = Field(schema_extra={"fail_fast": True})

with pytest.raises(ValidationError) as exc_info:
Model.model_validate({"val": [1.1, "not an int"]})

errors = exc_info.value.errors()
assert len(errors) == 1
assert errors[0]["type"] == "int_from_float"
Loading