diff --git a/go.mod b/go.mod index f377f33a8..ee3bf68b2 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/google/go-cmp v0.7.0 github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-version v1.8.0 + github.com/jackc/pgx/v5 v5.8.0 github.com/lensesio/tableprinter v0.0.0-20201125135848-89e81fc956e7 github.com/lib/pq v1.12.0 github.com/mark3labs/mcp-go v0.46.0 @@ -78,6 +79,9 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/kataras/tablewriter v0.0.0-20180708051242-e063d29b7c23 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/klauspost/compress v1.18.2 // indirect @@ -115,7 +119,6 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect google.golang.org/grpc v1.79.3 // indirect google.golang.org/protobuf v1.36.10 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 67bb38579..3bfdb9fbd 100644 --- a/go.sum +++ b/go.sum @@ -112,6 +112,14 @@ github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 
+github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/kataras/tablewriter v0.0.0-20180708051242-e063d29b7c23 h1:M8exrBzuhWcU6aoHJlHWPe4qFjVKzkMGRal78f5jRRU= github.com/kataras/tablewriter v0.0.0-20180708051242-e063d29b7c23/go.mod h1:kBSna6b0/RzsOcOZf515vAXwSsXYusl2U7SA0XP09yI= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= @@ -120,7 +128,6 @@ github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uq github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/connect-compress/v2 v2.1.0 h1:8fM8QrVeHT69e5VVSh4yjDaQASYIvOp2uMZq7nVLj2U= github.com/klauspost/connect-compress/v2 v2.1.0/go.mod h1:Ayurh2wscMMx3AwdGGVL+ylSR5316WfApREDgsqHyH8= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -218,6 +225,7 @@ github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjb github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= 
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/internal/cmd/importcmd/d1.go b/internal/cmd/importcmd/d1.go new file mode 100644 index 000000000..4d9e7f5e2 --- /dev/null +++ b/internal/cmd/importcmd/d1.go @@ -0,0 +1,142 @@ +package importcmd + +import ( + "fmt" + + "github.com/spf13/cobra" + + ps "github.com/planetscale/planetscale-go/planetscale" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/import/d1" + "github.com/planetscale/cli/internal/printer" +) + +const defaultD1Branch = "main" + +var d1DatabaseBranchArgs = cobra.RangeArgs(1, 2) + +func parseDatabaseBranch(args []string) (database, branch string) { + database = args[0] + branch = defaultD1Branch + if len(args) > 1 { + branch = args[1] + } + return database, branch +} + +func d1Org(ch *cmdutil.Helper) string { + return ch.Config.Organization +} + +func writeD1(ch *cmdutil.Helper, resp d1.Response) error { + if resp.Status == "error" { + switch ch.Printer.Format() { + case printer.JSON: + if err := ch.Printer.PrintJSON(resp); err != nil { + return err + } + case printer.Human: + d1.PrintHumanResponse(ch.Printer, resp) + default: + return fmt.Errorf(`import d1 does not support output format %q (use human or json)`, ch.Printer.Format()) + } + return d1CommandError(resp) + } + + switch ch.Printer.Format() { + case printer.JSON: + return ch.Printer.PrintJSON(resp) + case printer.Human: + d1.PrintHumanResponse(ch.Printer, resp) + return nil + default: + return fmt.Errorf(`import d1 does not support output format %q (use human or json)`, ch.Printer.Format()) + } +} + +func d1CommandError(resp d1.Response) error { + msg := "import d1 command failed" + if 
resp.Error != nil { + msg = resp.Error.Message + if resp.Error.Remediation != "" { + msg += "\n" + resp.Error.Remediation + } + } + return &cmdutil.Error{ + Msg: msg, + ExitCode: cmdutil.ActionRequestedExitCode, + Printed: true, + } +} + +func d1NotifyAPI(client *ps.Client, disabled bool) d1.NotifyAPIConfig { + return d1.NotifyAPIConfig{Client: client, Disabled: disabled} +} + +func importTableCount(prepared *d1.ImportPrepareResult) int { + if prepared == nil || prepared.Plan == nil { + return 0 + } + return countDataTables(prepared.Plan.Tables) +} + +func verifyTableCount(org, database, branch, migrationID, inputPath string) int { + path := inputPath + if path == "" && migrationID != "" { + if state, err := d1.LoadState(org, database, branch, migrationID); err == nil { + path = state.InputPath + } + } + if path == "" { + return 0 + } + tables, err := d1.ParseDump(path) + if err != nil { + return 0 + } + n := 0 + for _, t := range tables { + if !d1.IsORMMetadataTable(t.Name) { + n++ + } + } + return n +} + +func countDataTables(tables []d1.TablePlan) int { + n := 0 + for _, table := range tables { + if !d1.IsORMMetadataTable(table.Name) { + n++ + } + } + return n +} + +// D1Cmd returns the import d1 subcommand group. +func D1Cmd(ch *cmdutil.Helper) *cobra.Command { + cmd := &cobra.Command{ + Use: "d1 ", + Short: "Import Cloudflare D1 into PlanetScale Postgres", + Long: `Offline import from Cloudflare D1 (SQLite) to PlanetScale Postgres. + +Export your D1 database with wrangler (wrangler d1 export --remote --output ./d1-export.sql), +lint the dump, then start the import (use --dry-run to preview). +All commands support --format json for machine-readable output. 
+ +Branch-scoped commands use the same positional form as other PlanetScale CLI commands: + pscale import d1 start [branch] --input ./d1-export.sql +Org comes from your pscale config (pscale org).`, + } + + cmd.AddCommand(d1DoctorCmd(ch)) + cmd.AddCommand(d1LintCmd(ch)) + cmd.AddCommand(d1ConvertSchemaCmd(ch)) + cmd.AddCommand(d1StartCmd(ch)) + cmd.AddCommand(d1VerifyCmd(ch)) + cmd.AddCommand(d1StatusCmd(ch)) + cmd.AddCommand(d1CompleteCmd(ch)) + + return cmd +} diff --git a/internal/cmd/importcmd/d1_complete.go b/internal/cmd/importcmd/d1_complete.go new file mode 100644 index 000000000..947a6c6d2 --- /dev/null +++ b/internal/cmd/importcmd/d1_complete.go @@ -0,0 +1,52 @@ +package importcmd + +import ( + "github.com/spf13/cobra" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/import/d1" + "github.com/planetscale/cli/internal/printer" +) + +func d1CompleteCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + migrationID string + force bool + noNotify bool + } + + cmd := &cobra.Command{ + Use: "complete [branch]", + Aliases: []string{"teardown"}, + Short: "Mark a D1 migration as complete in local state", + Args: d1DatabaseBranchArgs, + Example: ` pscale import d1 complete mydb --migration-id abc123 + pscale import d1 complete mydb --migration-id abc123 --format json`, + RunE: func(cmd *cobra.Command, args []string) error { + database, branch := parseDatabaseBranch(args) + if !flags.force && ch.Printer.Format() == printer.Human { + if err := ch.Printer.ConfirmCommand(flags.migrationID, "import d1 complete", "complete"); err != nil { + return err + } + } + client, err := ch.Client() + if err != nil { + return writeD1(ch, d1.ErrorResponse("complete", err)) + } + resp, err := d1.CompleteResponse(d1Org(ch), database, branch, flags.migrationID) + if err != nil { + return writeD1(ch, d1.ErrorResponse("complete", err)) + } + if err := d1.Complete(d1Org(ch), database, branch, flags.migrationID, d1NotifyAPI(client, 
flags.noNotify)); err != nil { + return writeD1(ch, d1.ErrorResponse("complete", err)) + } + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.migrationID, "migration-id", "", "Migration ID") + cmd.Flags().BoolVar(&flags.force, "force", false, "Skip confirmation prompt") + cmd.Flags().BoolVar(&flags.noNotify, "no-notify", false, "Skip Slack notifications for this completion") + cmd.MarkFlagRequired("migration-id") + return cmd +} diff --git a/internal/cmd/importcmd/d1_complete_test.go b/internal/cmd/importcmd/d1_complete_test.go new file mode 100644 index 000000000..d45642966 --- /dev/null +++ b/internal/cmd/importcmd/d1_complete_test.go @@ -0,0 +1,89 @@ +package importcmd + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + ps "github.com/planetscale/planetscale-go/planetscale" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/config" + "github.com/planetscale/cli/internal/import/d1" + "github.com/planetscale/cli/internal/printer" +) + +func TestD1CompleteCmd(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + client, err := ps.NewClient( + ps.WithBaseURL(srv.URL), + ps.WithAccessToken("token"), + ) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + + const migrationID = "completecmd123" + fixture := d1FixturePath(t) + if err := d1.SavePlan(&d1.PlanResult{ + MigrationID: migrationID, + Org: "acme", + Database: "mydb", + Branch: "main", + InputPath: fixture, + }); err != nil { + t.Fatalf("SavePlan: %v", err) + } + if err := d1.SetMigrationPhase("acme", "mydb", "main", migrationID, d1.PhaseVerified); err != nil { + t.Fatalf("SetMigrationPhase: %v", err) + } + + var buf bytes.Buffer + format := printer.JSON + p := printer.NewPrinter(&format) + p.SetResourceOutput(&buf) + + ch := &cmdutil.Helper{ + Printer: p, 
+ Config: &config.Config{Organization: "acme"}, + Client: func() (*ps.Client, error) { + return client, nil + }, + } + + cmd := d1CompleteCmd(ch) + cmd.SetArgs([]string{"mydb", "--migration-id", migrationID, "--force"}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + if err := cmd.Execute(); err != nil { + t.Fatalf("execute: %v", err) + } + + assertJSONField(t, &buf, "command", "complete") + assertJSONField(t, &buf, "status", "ok") + assertJSONField(t, &buf, "migration_id", migrationID) + if !strings.Contains(buf.String(), "reminder") { + t.Fatalf("expected reminder in complete JSON output:\n%s", buf.String()) + } + if !strings.Contains(buf.String(), "next_steps") { + t.Fatalf("expected next_steps in complete JSON output:\n%s", buf.String()) + } +} + +func TestD1CompleteCmdRequiresMigrationID(t *testing.T) { + ch, _ := newD1TestHelper(t) + + cmd := d1CompleteCmd(ch) + if err := executeD1Cmd(t, cmd, "mydb"); err == nil { + t.Fatal("expected error when --migration-id is missing") + } +} diff --git a/internal/cmd/importcmd/d1_convert_schema.go b/internal/cmd/importcmd/d1_convert_schema.go new file mode 100644 index 000000000..7cc8d2082 --- /dev/null +++ b/internal/cmd/importcmd/d1_convert_schema.go @@ -0,0 +1,40 @@ +package importcmd + +import ( + "github.com/spf13/cobra" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/import/d1" +) + +func d1ConvertSchemaCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + input string + output string + } + + cmd := &cobra.Command{ + Use: "convert-schema", + Short: "Convert SQLite schema in a D1 export to PostgreSQL DDL", + RunE: func(cmd *cobra.Command, args []string) error { + if flags.output == "" { + flags.output = flags.input + ".postgres.sql" + } + count, err := d1.ConvertSchema(flags.input, flags.output) + if err != nil { + return writeD1(ch, d1.ErrorResponse("convert-schema", err)) + } + resp := d1.OKResponse("convert-schema", map[string]any{ + "input": flags.input, + 
"output": flags.output, + "table_count": count, + }, nil) + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.input, "input", "", "Path to D1 SQL export") + cmd.Flags().StringVar(&flags.output, "output", "", "Output PostgreSQL schema file") + cmd.MarkFlagRequired("input") + return cmd +} diff --git a/internal/cmd/importcmd/d1_convert_schema_test.go b/internal/cmd/importcmd/d1_convert_schema_test.go new file mode 100644 index 000000000..fbc3e9236 --- /dev/null +++ b/internal/cmd/importcmd/d1_convert_schema_test.go @@ -0,0 +1,29 @@ +package importcmd + +import ( + "path/filepath" + "testing" +) + +func TestD1ConvertSchemaCmd(t *testing.T) { + ch, buf := newD1TestHelper(t) + fixture := d1FixturePath(t) + output := filepath.Join(t.TempDir(), "schema.postgres.sql") + + cmd := d1ConvertSchemaCmd(ch) + if err := executeD1Cmd(t, cmd, "--input", fixture, "--output", output); err != nil { + t.Fatalf("execute: %v", err) + } + + assertJSONField(t, buf, "command", "convert-schema") + assertJSONField(t, buf, "status", "ok") +} + +func TestD1ConvertSchemaCmdRequiresInput(t *testing.T) { + ch, _ := newD1TestHelper(t) + + cmd := d1ConvertSchemaCmd(ch) + if err := executeD1Cmd(t, cmd); err == nil { + t.Fatal("expected error when --input is missing") + } +} diff --git a/internal/cmd/importcmd/d1_doctor.go b/internal/cmd/importcmd/d1_doctor.go new file mode 100644 index 000000000..acd15718f --- /dev/null +++ b/internal/cmd/importcmd/d1_doctor.go @@ -0,0 +1,24 @@ +package importcmd + +import ( + "github.com/spf13/cobra" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/import/d1" +) + +func d1DoctorCmd(ch *cmdutil.Helper) *cobra.Command { + cmd := &cobra.Command{ + Use: "doctor", + Short: "Check prerequisites for D1 migration", + RunE: func(cmd *cobra.Command, args []string) error { + result, err := d1.Doctor(cmd.Context()) + if err != nil { + return writeD1(ch, d1.ErrorResponse("doctor", err)) + } + return writeD1(ch, 
d1.DoctorResponse(result)) + }, + } + + return cmd +} diff --git a/internal/cmd/importcmd/d1_doctor_test.go b/internal/cmd/importcmd/d1_doctor_test.go new file mode 100644 index 000000000..9cd2ac246 --- /dev/null +++ b/internal/cmd/importcmd/d1_doctor_test.go @@ -0,0 +1,37 @@ +package importcmd + +import ( + "encoding/json" + "testing" +) + +func TestD1DoctorCmd(t *testing.T) { + ch, buf := newD1TestHelper(t) + + cmd := d1DoctorCmd(ch) + err := executeD1Cmd(t, cmd) + if buf.Len() == 0 { + t.Fatal("expected JSON output") + } + assertJSONField(t, buf, "command", "doctor") + if err == nil { + assertJSONField(t, buf, "status", "ok") + return + } + + var resp map[string]any + if unmarshalErr := json.Unmarshal(buf.Bytes(), &resp); unmarshalErr != nil { + t.Fatalf("unmarshal output: %v\nbody: %s", unmarshalErr, buf.String()) + } + if resp["status"] != "error" { + t.Fatalf("status = %v, want error", resp["status"]) + } + data, ok := resp["data"].(map[string]any) + if !ok { + t.Fatalf("data = %T, want object", resp["data"]) + } + checks, ok := data["checks"].([]any) + if !ok || len(checks) == 0 { + t.Fatalf("checks = %v, want non-empty array", data["checks"]) + } +} diff --git a/internal/cmd/importcmd/d1_lint.go b/internal/cmd/importcmd/d1_lint.go new file mode 100644 index 000000000..b2cac550c --- /dev/null +++ b/internal/cmd/importcmd/d1_lint.go @@ -0,0 +1,32 @@ +package importcmd + +import ( + "github.com/spf13/cobra" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/import/d1" +) + +func d1LintCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + input string + } + + cmd := &cobra.Command{ + Use: "lint", + Short: "Analyze a D1 SQL export for migration issues", + Example: ` pscale import d1 lint --input ./d1-export.sql --format json`, + RunE: func(cmd *cobra.Command, args []string) error { + result, err := d1.Lint(flags.input) + if err != nil { + return writeD1(ch, d1.ErrorResponse("lint", err)) + } + resp := 
d1.LintResponse(result) + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.input, "input", "", "Path to D1 SQL export") + cmd.MarkFlagRequired("input") + return cmd +} diff --git a/internal/cmd/importcmd/d1_lint_test.go b/internal/cmd/importcmd/d1_lint_test.go new file mode 100644 index 000000000..003d1633f --- /dev/null +++ b/internal/cmd/importcmd/d1_lint_test.go @@ -0,0 +1,29 @@ +package importcmd + +import ( + "testing" +) + +func TestD1LintCmd(t *testing.T) { + ch, buf := newD1TestHelper(t) + fixture := d1FixturePath(t) + + cmd := d1LintCmd(ch) + if err := executeD1Cmd(t, cmd, "--input", fixture); err != nil { + t.Fatalf("execute: %v", err) + } + + assertJSONField(t, buf, "command", "lint") + if status := jsonStatus(t, buf); status != "ok" && status != "warning" { + t.Fatalf("status = %q, want ok or warning", status) + } +} + +func TestD1LintCmdRequiresInput(t *testing.T) { + ch, _ := newD1TestHelper(t) + + cmd := d1LintCmd(ch) + if err := executeD1Cmd(t, cmd); err == nil { + t.Fatal("expected error when --input is missing") + } +} diff --git a/internal/cmd/importcmd/d1_progress.go b/internal/cmd/importcmd/d1_progress.go new file mode 100644 index 000000000..5ee643897 --- /dev/null +++ b/internal/cmd/importcmd/d1_progress.go @@ -0,0 +1,109 @@ +package importcmd + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/import/d1" + "github.com/planetscale/cli/internal/printer" +) + +const ( + progressPhaseImport = "import" + progressPhaseVerify = "verify" +) + +type progressReporter struct { + printer *printer.Printer + handle *printer.ProgressHandle + jsonMode bool + phase string +} + +func newImportProgressReporter(ch *cmdutil.Helper, tableCount int, sizeBytes int64) *progressReporter { + r := &progressReporter{ + printer: ch.Printer, + jsonMode: ch.Printer.Format() == printer.JSON, + phase: progressPhaseImport, + } + if r.jsonMode { + return r + } + msg := 
"Importing D1 export" + if tableCount > 0 { + msg = fmt.Sprintf("Importing D1 export (%d tables", tableCount) + if sizeBytes > 0 { + msg += fmt.Sprintf(", %.1f GB", float64(sizeBytes)/(1024*1024*1024)) + } + msg += ")..." + } else { + msg += "..." + } + r.handle = ch.Printer.StartProgress(msg) + return r +} + +func newVerifyProgressReporter(ch *cmdutil.Helper, tableCount int) *progressReporter { + r := &progressReporter{ + printer: ch.Printer, + jsonMode: ch.Printer.Format() == printer.JSON, + phase: progressPhaseVerify, + } + if r.jsonMode { + return r + } + msg := "Verifying D1 import..." + if tableCount > 0 { + msg = fmt.Sprintf("Verifying D1 import (%d tables)...", tableCount) + } + r.handle = ch.Printer.StartProgress(msg) + return r +} + +func (r *progressReporter) Close() { + if r.handle != nil { + r.handle.Stop() + } +} + +func (r *progressReporter) Report(p d1.ImportProgress) { + msg := formatProgressMessage(p) + if r.jsonMode { + r.writeJSON(p, msg) + return + } + if r.handle != nil { + r.handle.Update(msg) + return + } + fmt.Fprintln(os.Stderr, msg) +} + +func (r *progressReporter) writeJSON(p d1.ImportProgress, message string) { + payload := map[string]any{ + "type": "progress", + "phase": r.phase, + "stage": p.Stage, + "message": message, + } + if p.Current > 0 { + payload["current"] = p.Current + } + if p.Total > 0 { + payload["total"] = p.Total + } + if p.Detail != "" { + payload["detail"] = p.Detail + } + raw, err := json.Marshal(payload) + if err != nil { + return + } + fmt.Fprintln(os.Stderr, string(raw)) +} + +func formatProgressMessage(p d1.ImportProgress) string { + return d1.FormatProgressMessage(p) +} diff --git a/internal/cmd/importcmd/d1_progress_test.go b/internal/cmd/importcmd/d1_progress_test.go new file mode 100644 index 000000000..f21047458 --- /dev/null +++ b/internal/cmd/importcmd/d1_progress_test.go @@ -0,0 +1,53 @@ +package importcmd + +import ( + "testing" + + "github.com/planetscale/cli/internal/import/d1" +) + +func 
TestFormatProgressMessage(t *testing.T) { + tests := []struct { + name string + in d1.ImportProgress + want string + }{ + { + name: "import sqlite staging", + in: d1.ImportProgress{Stage: d1.ImportStageSQLiteStaging}, + want: "Staging SQLite database from export...", + }, + { + name: "import pgloader table", + in: d1.ImportProgress{ + Stage: d1.ImportStagePgloader, + Current: 3, + Total: 19, + Detail: "team_members", + }, + want: "Loading table 3/19: team_members", + }, + { + name: "verify row counts", + in: d1.ImportProgress{ + Stage: d1.VerifyStageRowCounts, + Current: 2, + Total: 19, + Detail: "users (postgres)", + }, + want: "Counting rows 2/19: users (postgres)", + }, + { + name: "verify sequences", + in: d1.ImportProgress{Stage: d1.VerifyStageSequences}, + want: "Checking identity sequences...", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatProgressMessage(tt.in); got != tt.want { + t.Fatalf("got %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/cmd/importcmd/d1_start.go b/internal/cmd/importcmd/d1_start.go new file mode 100644 index 000000000..ecb2b72ee --- /dev/null +++ b/internal/cmd/importcmd/d1_start.go @@ -0,0 +1,119 @@ +package importcmd + +import ( + "github.com/spf13/cobra" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/import/d1" + "github.com/planetscale/cli/internal/printer" +) + +func d1StartCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + input string + method string + migrationID string + dbName string + dryRun bool + force bool + noNotify bool + } + + cmd := &cobra.Command{ + Use: "start [branch]", + Short: "Start importing a D1 export (lint + plan, then load)", + Long: `Runs lint and builds an import plan, then loads data into PlanetScale Postgres. +Requires pgloader on PATH — run import d1 doctor to verify prerequisites. 
+ +Use --dry-run to lint and save migration state without touching Postgres.`, + Args: d1DatabaseBranchArgs, + Example: ` # Preview lint + plan and get a migration ID + pscale import d1 start mydb --input ./d1-export.sql --dry-run --format json + + # Run the import on a specific branch (human TTY prompts to confirm) + pscale import d1 start mydb dev --input ./d1-export.sql --method pgloader --format json`, + RunE: func(cmd *cobra.Command, args []string) error { + database, branch := parseDatabaseBranch(args) + org := d1Org(ch) + + importOpts := d1.ImportOptions{ + Org: org, + Database: database, + Branch: branch, + InputPath: flags.input, + Method: flags.method, + MigrationID: flags.migrationID, + DBName: flags.dbName, + DryRun: flags.dryRun, + } + + prepared, err := d1.PrepareImport(importOpts) + if err != nil { + return writeD1(ch, d1.ErrorResponse("start", err)) + } + + if !prepared.CanProceed { + return writeD1(ch, d1.BlockedStartResponse(prepared, flags.dryRun)) + } + + if !flags.force && !flags.dryRun && ch.Printer.Format() == printer.Human { + d1.PrintStartPreview(ch.Printer, prepared) + if err := ch.Printer.ConfirmCommand(prepared.MigrationID, "import d1 start", "start"); err != nil { + return err + } + } + + client, err := ch.Client() + if err != nil { + return writeD1(ch, d1.ErrorResponse("start", err)) + } + importOpts.NotifyAPI = d1NotifyAPI(client, flags.noNotify) + + var progress *progressReporter + if !flags.dryRun { + tableCount := importTableCount(prepared) + progress = newImportProgressReporter(ch, tableCount, prepared.Plan.EstimatedSizeBytes) + importOpts.OnProgress = progress.Report + importOpts.PgloaderVerbose = ch.Debug() + } + + result, err := d1.Import(cmd.Context(), client, &d1.DefaultImportClient{Client: client}, importOpts, prepared) + if progress != nil { + progress.Close() + } + if err != nil { + resp := d1.ErrorResponse("start", err) + if result != nil { + resp.Data = result + if result.Lint != nil { + resp.Issues = result.Lint.Issues 
+ } + resp.MigrationID = result.MigrationID + } else { + resp.MigrationID = prepared.MigrationID + } + return writeD1(ch, resp) + } + resp := d1.OKResponse("start", result, d1.StartNextSteps(result.MigrationID, database, branch, result.Method, flags.input, flags.dryRun)) + resp.MigrationID = result.MigrationID + resp.Issues = result.Lint.Issues + if flags.dryRun { + resp.Status = "dry_run" + resp.Phase = d1.PhasePlanned + } else { + resp.Phase = d1.PhaseImported + } + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.input, "input", "", "Path to D1 SQL export") + cmd.Flags().StringVar(&flags.method, "method", "", "Import method: pgloader (≥1GB) or psql (<1GB; schema via psql, data via pgloader)") + cmd.Flags().StringVar(&flags.migrationID, "migration-id", "", "Existing migration ID from a prior start --dry-run") + cmd.Flags().StringVar(&flags.dbName, "dbname", "postgres", "Destination PostgreSQL database name") + cmd.Flags().BoolVar(&flags.dryRun, "dry-run", false, "Lint and build import plan without loading Postgres") + cmd.Flags().BoolVar(&flags.force, "force", false, "Skip confirmation prompt") + cmd.Flags().BoolVar(&flags.noNotify, "no-notify", false, "Skip Slack notifications for this import") + cmd.MarkFlagRequired("input") + return cmd +} diff --git a/internal/cmd/importcmd/d1_start_test.go b/internal/cmd/importcmd/d1_start_test.go new file mode 100644 index 000000000..ec152ae9e --- /dev/null +++ b/internal/cmd/importcmd/d1_start_test.go @@ -0,0 +1,29 @@ +package importcmd + +import ( + "testing" +) + +func TestD1StartCmdDryRun(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + ch, buf := newD1TestHelper(t) + fixture := d1FixturePath(t) + + cmd := d1StartCmd(ch) + if err := executeD1Cmd(t, cmd, "mydb", "--input", fixture, "--dry-run", "--force"); err != nil { + t.Fatalf("execute: %v", err) + } + + assertJSONField(t, buf, "command", "start") + assertJSONField(t, buf, "status", "dry_run") +} + +func TestD1StartCmdRequiresInput(t 
*testing.T) { + ch, _ := newD1TestHelper(t) + + cmd := d1StartCmd(ch) + if err := executeD1Cmd(t, cmd, "mydb"); err == nil { + t.Fatal("expected error when --input is missing") + } +} diff --git a/internal/cmd/importcmd/d1_status.go b/internal/cmd/importcmd/d1_status.go new file mode 100644 index 000000000..e306ba056 --- /dev/null +++ b/internal/cmd/importcmd/d1_status.go @@ -0,0 +1,33 @@ +package importcmd + +import ( + "github.com/spf13/cobra" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/import/d1" +) + +func d1StatusCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + migrationID string + } + + cmd := &cobra.Command{ + Use: "status [branch]", + Short: "Show local migration state", + Args: d1DatabaseBranchArgs, + Example: ` pscale import d1 status mydb --migration-id abc123`, + RunE: func(cmd *cobra.Command, args []string) error { + database, branch := parseDatabaseBranch(args) + state, err := d1.Status(d1Org(ch), database, branch, flags.migrationID) + if err != nil { + return writeD1(ch, d1.ErrorResponse("status", err)) + } + return writeD1(ch, d1.StatusResponse(state)) + }, + } + + cmd.Flags().StringVar(&flags.migrationID, "migration-id", "", "Migration ID") + cmd.MarkFlagRequired("migration-id") + return cmd +} diff --git a/internal/cmd/importcmd/d1_status_test.go b/internal/cmd/importcmd/d1_status_test.go new file mode 100644 index 000000000..ba5a0ab73 --- /dev/null +++ b/internal/cmd/importcmd/d1_status_test.go @@ -0,0 +1,42 @@ +package importcmd + +import ( + "testing" + + "github.com/planetscale/cli/internal/import/d1" +) + +func TestD1StatusCmd(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + const migrationID = "statuscmd123" + fixture := d1FixturePath(t) + if err := d1.SavePlan(&d1.PlanResult{ + MigrationID: migrationID, + Org: "acme", + Database: "mydb", + Branch: "main", + InputPath: fixture, + }); err != nil { + t.Fatalf("SavePlan: %v", err) + } + + ch, buf := newD1TestHelper(t) + cmd := 
d1StatusCmd(ch) + if err := executeD1Cmd(t, cmd, "mydb", "--migration-id", migrationID); err != nil { + t.Fatalf("execute: %v", err) + } + + assertJSONField(t, buf, "command", "status") + assertJSONField(t, buf, "status", "ok") + assertJSONField(t, buf, "migration_id", migrationID) +} + +func TestD1StatusCmdRequiresMigrationID(t *testing.T) { + ch, _ := newD1TestHelper(t) + + cmd := d1StatusCmd(ch) + if err := executeD1Cmd(t, cmd, "mydb"); err == nil { + t.Fatal("expected error when --migration-id is missing") + } +} diff --git a/internal/cmd/importcmd/d1_test.go b/internal/cmd/importcmd/d1_test.go new file mode 100644 index 000000000..7b89b85d1 --- /dev/null +++ b/internal/cmd/importcmd/d1_test.go @@ -0,0 +1,146 @@ +package importcmd + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "path/filepath" + "testing" + + "github.com/spf13/cobra" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/config" + "github.com/planetscale/cli/internal/import/d1" + "github.com/planetscale/cli/internal/printer" + ps "github.com/planetscale/planetscale-go/planetscale" +) + +func newD1TestHelper(t *testing.T) (*cmdutil.Helper, *bytes.Buffer) { + t.Helper() + + var buf bytes.Buffer + format := printer.JSON + p := printer.NewPrinter(&format) + p.SetResourceOutput(&buf) + + ch := &cmdutil.Helper{ + Printer: p, + Config: &config.Config{Organization: "acme"}, + Client: func() (*ps.Client, error) { + return &ps.Client{}, nil + }, + } + return ch, &buf +} + +func d1FixturePath(t *testing.T) string { + t.Helper() + return filepath.Clean(filepath.Join("..", "..", "import", "d1", "testdata", "sample_d1_export.sql")) +} + +func executeD1Cmd(t *testing.T, cmd *cobra.Command, args ...string) error { + t.Helper() + cmd.SetArgs(args) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + return cmd.Execute() +} + +func assertJSONField(t *testing.T, buf *bytes.Buffer, field string, want any) { + t.Helper() + got := jsonField(t, buf, field) + if got != want 
{ + t.Fatalf("%s = %v, want %v", field, got, want) + } +} + +func jsonField(t *testing.T, buf *bytes.Buffer, field string) any { + t.Helper() + var resp map[string]any + if err := json.Unmarshal(buf.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal output: %v\nbody: %s", err, buf.String()) + } + got, ok := resp[field] + if !ok { + t.Fatalf("response missing %q\nbody: %s", field, buf.String()) + } + return got +} + +func jsonStatus(t *testing.T, buf *bytes.Buffer) string { + t.Helper() + got, ok := jsonField(t, buf, "status").(string) + if !ok { + t.Fatalf("status field is %T, want string", jsonField(t, buf, "status")) + } + return got +} + +func TestParseDatabaseBranch(t *testing.T) { + database, branch := parseDatabaseBranch([]string{"mydb"}) + if database != "mydb" || branch != "main" { + t.Fatalf("got (%q, %q), want (mydb, main)", database, branch) + } + + database, branch = parseDatabaseBranch([]string{"mydb", "dev"}) + if database != "mydb" || branch != "dev" { + t.Fatalf("got (%q, %q), want (mydb, dev)", database, branch) + } +} + +func TestWriteD1ErrorUsesConsistentExitCode(t *testing.T) { + resp := d1.LintResponse(&d1.LintResult{ + TableCount: 1, + ErrorCount: 1, + Issues: []d1.Issue{{ + Code: "VIRTUAL_TABLE", + Severity: d1.SeverityError, + Table: "fts", + Remediation: "Virtual tables are not supported", + }}, + }) + + for _, format := range []printer.Format{printer.Human, printer.JSON} { + t.Run(format.String(), func(t *testing.T) { + var buf bytes.Buffer + p := printer.NewPrinter(&format) + if format == printer.Human { + p.SetHumanOutput(&buf) + } else { + p.SetResourceOutput(&buf) + } + + err := writeD1(&cmdutil.Helper{Printer: p}, resp) + if err == nil { + t.Fatal("expected error") + } + + var cmdErr *cmdutil.Error + if !errors.As(err, &cmdErr) { + t.Fatalf("expected *cmdutil.Error, got %T: %v", err, err) + } + if cmdErr.ExitCode != cmdutil.ActionRequestedExitCode { + t.Fatalf("exit code = %d, want %d", cmdErr.ExitCode, 
cmdutil.ActionRequestedExitCode) + } + if !cmdErr.Printed { + t.Fatal("expected output to be marked printed") + } + if buf.Len() == 0 { + t.Fatal("expected response output") + } + }) + } +} + +func TestCountDataTablesSkipsORMMetadata(t *testing.T) { + got := countDataTables([]d1.TablePlan{ + {Name: "users"}, + {Name: "_prisma_migrations"}, + {Name: "posts"}, + }) + if got != 2 { + t.Fatalf("countDataTables = %d, want 2", got) + } +} diff --git a/internal/cmd/importcmd/d1_verify.go b/internal/cmd/importcmd/d1_verify.go new file mode 100644 index 000000000..1245dfcc3 --- /dev/null +++ b/internal/cmd/importcmd/d1_verify.go @@ -0,0 +1,82 @@ +package importcmd + +import ( + "github.com/spf13/cobra" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/import/d1" +) + +func d1VerifyCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + migrationID string + input string + sqlite string + dbName string + noNotify bool + } + + cmd := &cobra.Command{ + Use: "verify [branch]", + Short: "Verify D1 import (row counts, sequences, coercion, content checks)", + Args: d1DatabaseBranchArgs, + Example: ` pscale import d1 verify mydb --migration-id abc123 --input ./d1-export.sql + pscale import d1 verify mydb dev --migration-id abc123 --input ./d1-export.sql --format json`, + RunE: func(cmd *cobra.Command, args []string) error { + database, branch := parseDatabaseBranch(args) + org := d1Org(ch) + + verifyOpts := d1.VerifyOptions{ + Org: org, + Database: database, + Branch: branch, + MigrationID: flags.migrationID, + InputPath: flags.input, + SQLitePath: flags.sqlite, + DBName: flags.dbName, + } + + client, err := ch.Client() + if err != nil { + return writeD1(ch, d1.ErrorResponse("verify", err)) + } + verifyOpts.NotifyAPI = d1NotifyAPI(client, flags.noNotify) + destURI, cleanup, err := d1.ResolveDestURI(cmd.Context(), client, d1.ImportOptions{ + Org: org, + Database: database, + Branch: branch, + DBName: flags.dbName, + }) + if err != nil { 
+ return writeD1(ch, d1.ErrorResponse("verify", err)) + } + defer func() { _ = cleanup() }() + verifyOpts.DestURI = destURI + + progress := newVerifyProgressReporter(ch, verifyTableCount(org, database, branch, flags.migrationID, flags.input)) + verifyOpts.OnProgress = progress.Report + + result, err := d1.Verify(cmd.Context(), verifyOpts) + progress.Close() + if err != nil { + resp := d1.ErrorResponse("verify", err) + if result != nil { + resp.Data = result + } + return writeD1(ch, resp) + } + resp := d1.OKResponse("verify", result, d1.VerifyNextSteps(flags.migrationID, database, branch)) + resp.MigrationID = flags.migrationID + resp.Phase = d1.PhaseVerified + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.migrationID, "migration-id", "", "Migration ID from plan/import") + cmd.Flags().StringVar(&flags.input, "input", "", "Path to original D1 SQL export") + cmd.Flags().StringVar(&flags.sqlite, "sqlite", "", "Path to local SQLite file for source counts") + cmd.Flags().StringVar(&flags.dbName, "dbname", "postgres", "Destination PostgreSQL database name") + cmd.Flags().BoolVar(&flags.noNotify, "no-notify", false, "Skip Slack notifications for this verification") + cmd.MarkFlagRequired("migration-id") + return cmd +} diff --git a/internal/cmd/importcmd/d1_verify_test.go b/internal/cmd/importcmd/d1_verify_test.go new file mode 100644 index 000000000..21c93be12 --- /dev/null +++ b/internal/cmd/importcmd/d1_verify_test.go @@ -0,0 +1,14 @@ +package importcmd + +import ( + "testing" +) + +func TestD1VerifyCmdRequiresMigrationID(t *testing.T) { + ch, _ := newD1TestHelper(t) + + cmd := d1VerifyCmd(ch) + if err := executeD1Cmd(t, cmd, "mydb"); err == nil { + t.Fatal("expected error when --migration-id is missing") + } +} diff --git a/internal/cmd/importcmd/import.go b/internal/cmd/importcmd/import.go new file mode 100644 index 000000000..3c740a66e --- /dev/null +++ b/internal/cmd/importcmd/import.go @@ -0,0 +1,27 @@ +package importcmd + +import ( + 
"github.com/planetscale/cli/internal/cmdutil" + "github.com/spf13/cobra" +) + +// ImportCmd returns the import command group. +func ImportCmd(ch *cmdutil.Helper) *cobra.Command { + cmd := &cobra.Command{ + Use: "import", + Short: "Import external databases into PlanetScale Postgres", + Long: `Import databases from external sources into PlanetScale Postgres. + +Available sources: + d1 Import from Cloudflare D1 using an offline SQLite export`, + PersistentPreRunE: cmdutil.CheckAuthentication(ch.Config), + } + + cmd.PersistentFlags().StringVar(&ch.Config.Organization, "org", ch.Config.Organization, + "The organization for the current user") + cmd.MarkPersistentFlagRequired("org") // nolint:errcheck + + cmd.AddCommand(D1Cmd(ch)) + + return cmd +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 38f03e090..cc229d734 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -41,6 +41,7 @@ import ( "github.com/planetscale/cli/internal/cmd/database" "github.com/planetscale/cli/internal/cmd/dataimports" "github.com/planetscale/cli/internal/cmd/deployrequest" + "github.com/planetscale/cli/internal/cmd/importcmd" "github.com/planetscale/cli/internal/cmd/keyspace" "github.com/planetscale/cli/internal/cmd/org" "github.com/planetscale/cli/internal/cmd/password" @@ -118,16 +119,17 @@ func Execute(ctx context.Context, sigc chan os.Signal, signals []os.Signal, ver, return 0 } - // print any user specific messages first - switch format { - case printer.JSON: - fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err) - default: - fmt.Fprintf(os.Stderr, "Error: %s\n", err) + var cmdErr *cmdutil.Error + printed := errors.As(err, &cmdErr) && cmdErr.Printed + if !printed { + switch format { + case printer.JSON: + fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err) + default: + fmt.Fprintf(os.Stderr, "Error: %s\n", err) + } } - // check if a sub command wants to return a specific exit code - var cmdErr *cmdutil.Error if errors.As(err, &cmdErr) { return cmdErr.ExitCode } @@ -312,6 
+314,10 @@ func runCmd(ctx context.Context, ver, commit, buildDate string, format *printer. shellCmd.GroupID = "database" rootCmd.AddCommand(shellCmd) + importCmd := importcmd.ImportCmd(ch) + importCmd.GroupID = "postgres" + rootCmd.AddCommand(importCmd) + workflowCmd := workflow.WorkflowCmd(ch) workflowCmd.GroupID = "vitess" rootCmd.AddCommand(workflowCmd) diff --git a/internal/cmdutil/errors.go b/internal/cmdutil/errors.go index 2dd48bafd..28f32c44f 100644 --- a/internal/cmdutil/errors.go +++ b/internal/cmdutil/errors.go @@ -18,8 +18,11 @@ var errExpiredAuthMessage = errors.New("the access token has expired. Please run // Error can be used by a command to change the exit status of the CLI. type Error struct { Msg string - // Status + // ExitCode is returned to the shell when the command fails. ExitCode int + // Printed indicates the error output was already written (e.g. to stdout); + // root should not print Msg to stderr again. + Printed bool } func (e *Error) Error() string { return e.Msg } diff --git a/internal/import/d1/.gitignore b/internal/import/d1/.gitignore new file mode 100644 index 000000000..be303db03 --- /dev/null +++ b/internal/import/d1/.gitignore @@ -0,0 +1 @@ +*.fasl diff --git a/internal/import/d1/coerce_samples.go b/internal/import/d1/coerce_samples.go new file mode 100644 index 000000000..5d2aa92c6 --- /dev/null +++ b/internal/import/d1/coerce_samples.go @@ -0,0 +1,387 @@ +package d1 + +import ( + "bufio" + "encoding/json" + "os" + "regexp" + "slices" + "strings" +) + +var ( + uuidValueRe = regexp.MustCompile(`(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) + timestampValueRe = regexp.MustCompile(`(?i)^\d{4}-\d{2}-\d{2}(?:[ T]\d{2}:\d{2}(?::\d{2}(?:\.\d+)?)?(?:Z|[+-]\d{2}:?\d{2})?)?$`) +) + +// ColumnSamples holds sampled INSERT values per table/column. +type ColumnSamples map[string]map[string][]string + +// TypeCoercionContext carries sampled values used to validate name-based coercions. 
+type TypeCoercionContext struct { + Samples ColumnSamples +} + +func BuildTypeCoercionContext(inputPath string, tables []TableSchema) (*TypeCoercionContext, error) { + samples, err := SampleColumnValues(inputPath, tables) + if err != nil { + return nil, err + } + return &TypeCoercionContext{Samples: samples}, nil +} + +func (ctx *TypeCoercionContext) samplesFor(table, column string) []string { + if ctx == nil || ctx.Samples == nil { + return nil + } + if cols, ok := ctx.Samples[table]; ok { + return cols[column] + } + return nil +} + +// SampleColumnValues reads INSERT statements and collects non-null literal values. +func SampleColumnValues(path string, tables []TableSchema) (ColumnSamples, error) { + clean, err := ValidateInputPath(path) + if err != nil { + return nil, err + } + f, err := os.Open(clean) + if err != nil { + return nil, err + } + defer f.Close() + + tableCols := make(map[string][]string, len(tables)) + for _, t := range tables { + cols := make([]string, 0, len(t.Columns)) + for _, c := range t.Columns { + cols = append(cols, c.Name) + } + tableCols[t.Name] = cols + } + + samples := make(ColumnSamples) + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 10*1024*1024) + + var pendingInsert strings.Builder + + flushInsert := func() { + line := strings.TrimSpace(pendingInsert.String()) + pendingInsert.Reset() + if line == "" { + return + } + m := insertRe.FindStringSubmatch(line) + if m == nil { + return + } + table := firstNonEmpty(m[1], m[2], m[3], m[4]) + columns, valueGroups, ok := parseInsertColumnsAndValues(line) + if !ok || len(valueGroups) == 0 { + return + } + if len(columns) == 0 { + columns = tableCols[table] + } + if len(columns) == 0 { + return + } + if samples[table] == nil { + samples[table] = make(map[string][]string) + } + for _, values := range valueGroups { + for i, col := range columns { + if i >= len(values) { + break + } + val := values[i] + if val == "" || strings.EqualFold(val, "NULL") { + continue + } + 
val = unquoteSQLLiteral(val) + samples[table][col] = appendUniqueSample(samples[table][col], val, 32) + } + } + } + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "--") { + continue + } + + if pendingInsert.Len() > 0 { + pendingInsert.WriteString(" ") + pendingInsert.WriteString(line) + if strings.HasSuffix(line, ";") { + flushInsert() + } + continue + } + + m := insertRe.FindStringSubmatch(line) + if m == nil { + continue + } + if strings.HasSuffix(line, ";") { + pendingInsert.WriteString(line) + flushInsert() + continue + } + pendingInsert.WriteString(line) + } + flushInsert() + + if err := scanner.Err(); err != nil { + return nil, err + } + return samples, nil +} + +func appendUniqueSample(existing []string, val string, max int) []string { + if slices.Contains(existing, val) { + return existing + } + if len(existing) >= max { + return existing + } + return append(existing, val) +} + +func parseInsertColumnsAndValues(line string) (columns []string, valueGroups [][]string, ok bool) { + upper := strings.ToUpper(line) + valuesIdx := strings.Index(upper, " VALUES") + if valuesIdx < 0 { + return nil, nil, false + } + head := line[:valuesIdx] + valuesPart := strings.TrimSpace(line[valuesIdx+len(" VALUES"):]) + valuesPart = strings.TrimSuffix(valuesPart, ";") + + openParen := strings.Index(head, "(") + closeParen := strings.LastIndex(head, ")") + if openParen >= 0 && closeParen > openParen { + colPart := head[openParen+1 : closeParen] + for _, part := range splitCommaList(colPart) { + name, _ := parseColumnNameAndRest(strings.TrimSpace(part)) + if name != "" { + columns = append(columns, name) + } + } + } + + valuesPart = strings.TrimSpace(valuesPart) + if !strings.HasPrefix(valuesPart, "(") { + return columns, nil, false + } + + // Split one or more (...) value tuples. 
+ var tuples []string + var current strings.Builder + depth := 0 + inSingle := false + inDouble := false + for i := 0; i < len(valuesPart); i++ { + c := valuesPart[i] + switch c { + case '\'': + if !inDouble { + if inSingle && i+1 < len(valuesPart) && valuesPart[i+1] == '\'' { + current.WriteByte(c) + current.WriteByte(valuesPart[i+1]) + i++ + continue + } + inSingle = !inSingle + } + current.WriteByte(c) + case '"': + if !inSingle { + if inDouble && i+1 < len(valuesPart) && valuesPart[i+1] == '"' { + current.WriteByte(c) + current.WriteByte(valuesPart[i+1]) + i++ + continue + } + inDouble = !inDouble + } + current.WriteByte(c) + case '(': + if !inSingle && !inDouble { + if depth == 0 { + current.Reset() + } + depth++ + } + if depth > 0 { + current.WriteByte(c) + } + case ')': + if !inSingle && !inDouble { + if depth > 0 { + current.WriteByte(c) + } + depth-- + if depth == 0 { + inner := strings.TrimSpace(current.String()) + inner = strings.TrimPrefix(inner, "(") + inner = strings.TrimSuffix(inner, ")") + tuples = append(tuples, inner) + current.Reset() + } + continue + } + current.WriteByte(c) + default: + if depth > 0 { + current.WriteByte(c) + } + } + } + + for _, tuple := range tuples { + valueGroups = append(valueGroups, splitInsertValues(tuple)) + } + return columns, valueGroups, len(valueGroups) > 0 +} + +func splitInsertValues(s string) []string { + var parts []string + var current strings.Builder + depth := 0 + inSingle := false + inDouble := false + + for i := 0; i < len(s); i++ { + c := s[i] + switch c { + case '\'': + if !inDouble { + if inSingle && i+1 < len(s) && s[i+1] == '\'' { + current.WriteByte(c) + current.WriteByte(s[i+1]) + i++ + continue + } + inSingle = !inSingle + } + current.WriteByte(c) + case '"': + if !inSingle { + if inDouble && i+1 < len(s) && s[i+1] == '"' { + current.WriteByte(c) + current.WriteByte(s[i+1]) + i++ + continue + } + inDouble = !inDouble + } + current.WriteByte(c) + case '(': + if !inSingle && !inDouble { + depth++ + } 
+ current.WriteByte(c) + case ')': + if !inSingle && !inDouble { + depth-- + } + current.WriteByte(c) + case ',': + if depth == 0 && !inSingle && !inDouble { + parts = append(parts, strings.TrimSpace(current.String())) + current.Reset() + continue + } + current.WriteByte(c) + default: + current.WriteByte(c) + } + } + if current.Len() > 0 { + parts = append(parts, strings.TrimSpace(current.String())) + } + return parts +} + +func unquoteSQLLiteral(s string) string { + s = strings.TrimSpace(s) + if len(s) >= 2 && s[0] == '\'' && s[len(s)-1] == '\'' { + inner := s[1 : len(s)-1] + return strings.ReplaceAll(inner, "''", "'") + } + if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' { + inner := s[1 : len(s)-1] + return strings.ReplaceAll(inner, `""`, `"`) + } + return s +} + +func looksLikeUUID(s string) bool { + return uuidValueRe.MatchString(strings.TrimSpace(s)) +} + +func looksLikeJSON(s string) bool { + s = strings.TrimSpace(s) + if s == "" { + return false + } + switch s[0] { + case '{', '[': + var v any + return json.Unmarshal([]byte(s), &v) == nil + default: + return false + } +} + +func looksLikeTimestamp(s string) bool { + s = strings.TrimSpace(s) + if s == "" { + return false + } + return timestampValueRe.MatchString(s) +} + +func samplesLookBoolean(table, column string, ctx *TypeCoercionContext) bool { + vals := ctx.samplesFor(table, column) + if len(vals) == 0 { + return false + } + for _, v := range vals { + if v == "" { + continue + } + if v != "0" && v != "1" { + return false + } + } + return true +} + +func samplesAllow(table, column string, ctx *TypeCoercionContext, allow func(string) bool) bool { + vals := ctx.samplesFor(table, column) + if len(vals) == 0 { + return false + } + for _, v := range vals { + if v != "" && !allow(v) { + return false + } + } + return true +} + +func samplesAllowUUID(table, column string, ctx *TypeCoercionContext) bool { + return samplesAllow(table, column, ctx, looksLikeUUID) +} + +func samplesAllowJSON(table, column 
string, ctx *TypeCoercionContext) bool { + return samplesAllow(table, column, ctx, looksLikeJSON) +} + +func samplesAllowTimestamp(table, column string, ctx *TypeCoercionContext) bool { + return samplesAllow(table, column, ctx, looksLikeTimestamp) +} diff --git a/internal/import/d1/coerce_samples_test.go b/internal/import/d1/coerce_samples_test.go new file mode 100644 index 000000000..b8d4d7a42 --- /dev/null +++ b/internal/import/d1/coerce_samples_test.go @@ -0,0 +1,40 @@ +package d1 + +import "testing" + +func TestLooksLikeTimestamp(t *testing.T) { + cases := map[string]bool{ + "2024-01-15 12:00:00": true, + "2024-01-15T12:00:00Z": true, + "v1-beta": false, + "note: pending": false, + "draft-2024": false, + "CURRENT_TIMESTAMP": false, + } + for val, want := range cases { + if got := looksLikeTimestamp(val); got != want { + t.Fatalf("looksLikeTimestamp(%q) = %v, want %v", val, got, want) + } + } +} + +func TestSampleColumnValuesExternalEntities(t *testing.T) { + tables, err := ParseDump(testFixture(t)) + if err != nil { + t.Fatal(err) + } + samples, err := SampleColumnValues(testFixture(t), tables) + if err != nil { + t.Fatal(err) + } + if len(samples["external_entities"]["id"]) == 0 { + t.Fatalf("expected external_entities.id samples, got %#v", samples["external_entities"]) + } + ctx, err := BuildTypeCoercionContext(testFixture(t), tables) + if err != nil { + t.Fatal(err) + } + if !samplesAllowUUID("external_entities", "id", ctx) { + t.Fatal("expected uuid samples for external_entities.id") + } +} diff --git a/internal/import/d1/complete.go b/internal/import/d1/complete.go new file mode 100644 index 000000000..ece034a79 --- /dev/null +++ b/internal/import/d1/complete.go @@ -0,0 +1,124 @@ +package d1 + +import ( + "strings" + + "github.com/planetscale/cli/internal/printer" +) + +const completeReminderShort = "ORM migration tables were not imported. Re-baseline your migration history on Postgres before cutover." 
+ +// CompleteResult is the data payload for import d1 complete. +type CompleteResult struct { + MigrationID string `json:"migration_id"` + Status string `json:"status"` + SkippedORMTables []string `json:"skipped_orm_tables,omitempty"` +} + +// CompleteResponse builds the success envelope for import d1 complete. +func CompleteResponse(org, database, branch, migrationID string) (Response, error) { + state, err := LoadState(org, database, branch, migrationID) + if err != nil { + return Response{}, err + } + + skippedTables, nextSteps, err := completeORMNextSteps(state.InputPath) + if err != nil { + return Response{}, err + } + + resp := OKResponse("complete", CompleteResult{ + MigrationID: migrationID, + Status: PhaseComplete, + SkippedORMTables: skippedTables, + }, nextSteps) + resp.MigrationID = migrationID + resp.Phase = PhaseComplete + resp.Reminder = completeReminderShort + return resp, nil +} + +// CompleteSlackMessage returns a short Slack-friendly completion line. +func CompleteSlackMessage(skippedTables []string, orms []string) string { + if len(skippedTables) == 0 { + return "D1 import marked complete. If your app uses an ORM, re-baseline migration history on Postgres before cutover." + } + msg := "Data import complete. Next: re-baseline ORM migrations on Postgres — migration history tables from D1 were not imported." + if len(orms) > 0 { + msg += " Detected: " + strings.Join(orms, ", ") + "." 
+ } + return msg +} + +func completeORMNextSteps(inputPath string) (skippedTables []string, steps []NextStep, err error) { + if inputPath == "" { + return nil, genericORMCompleteNextSteps(), nil + } + + tables, err := ParseDump(inputPath) + if err != nil { + return nil, nil, err + } + + seenORM := make(map[string]struct{}) + for _, table := range tables { + rule := ORMMetadataRule(table.Name) + if rule == nil { + continue + } + skippedTables = append(skippedTables, table.Name) + if _, ok := seenORM[rule.orm]; ok { + continue + } + seenORM[rule.orm] = struct{}{} + steps = append(steps, NextStep{ + Tool: rule.orm, + Reason: rule.remediation, + }) + } + + if len(steps) == 0 { + steps = genericORMCompleteNextSteps() + } + return skippedTables, steps, nil +} + +func genericORMCompleteNextSteps() []NextStep { + return []NextStep{{ + Tool: "ORM migrations", + Reason: "If your app uses an ORM or migration framework (Drizzle, Prisma, Rails, etc.), " + + "re-baseline migration history on Postgres. 
SQLite bookkeeping tables are never imported.", + }} +} + +func ormNamesFromSkippedTables(skippedTables []string) []string { + seen := make(map[string]struct{}) + var names []string + for _, table := range skippedTables { + rule := ORMMetadataRule(table) + if rule == nil { + continue + } + if _, ok := seen[rule.orm]; ok { + continue + } + seen[rule.orm] = struct{}{} + names = append(names, rule.orm) + } + return names +} + +func printCompleteReminderHuman(p *printer.Printer, result CompleteResult) { + p.Println("\nReminder: ORM migration history was not imported") + p.Println(" Application data is in Postgres, but framework tables such as") + p.Println(" __drizzle_migrations, _prisma_migrations, and schema_migrations") + p.Println(" were intentionally skipped.") + if len(result.SkippedORMTables) > 0 { + p.Printf(" Skipped in this export: %s\n", strings.Join(result.SkippedORMTables, ", ")) + } + p.Println("\n Before cutover:") + p.Println(" • Point your app at the PlanetScale Postgres branch") + p.Println(" • Re-baseline migrations on Postgres (do not copy SQLite history)") + p.Println(" • Run your ORM's mark-applied / fake-initial flow for the current schema") + p.Println("\n Run pscale import d1 lint --input for ORM-specific guidance.") +} diff --git a/internal/import/d1/complete_test.go b/internal/import/d1/complete_test.go new file mode 100644 index 000000000..88a98264a --- /dev/null +++ b/internal/import/d1/complete_test.go @@ -0,0 +1,78 @@ +package d1 + +import ( + "strings" + "testing" +) + +func TestCompleteResponseIncludesORMGuidance(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + org, database, branch := "acme", "mydb", "main" + migrationID := "complete-response-123" + if err := SavePlan(&PlanResult{ + MigrationID: migrationID, + Org: org, + Database: database, + Branch: branch, + InputPath: testFixture(t), + }); err != nil { + t.Fatalf("SavePlan: %v", err) + } + if err := SetMigrationPhase(org, database, branch, migrationID, PhaseVerified); err != 
nil { + t.Fatalf("SetMigrationPhase: %v", err) + } + + resp, err := CompleteResponse(org, database, branch, migrationID) + if err != nil { + t.Fatalf("CompleteResponse: %v", err) + } + if resp.Reminder == "" { + t.Fatal("expected reminder") + } + if len(resp.NextSteps) == 0 { + t.Fatal("expected ORM next steps") + } + + data, ok := resp.Data.(CompleteResult) + if !ok { + t.Fatalf("data type = %T, want CompleteResult", resp.Data) + } + if len(data.SkippedORMTables) == 0 { + t.Fatal("expected skipped ORM tables from sample fixture") + } + + foundDrizzle := false + foundPrisma := false + for _, step := range resp.NextSteps { + switch step.Tool { + case "Drizzle": + foundDrizzle = true + case "Prisma": + foundPrisma = true + } + } + if !foundDrizzle || !foundPrisma { + t.Fatalf("next steps = %#v, want Drizzle and Prisma", resp.NextSteps) + } +} + +func TestCompleteSlackMessageWithDetectedORMs(t *testing.T) { + msg := CompleteSlackMessage( + []string{"__drizzle_migrations", "_prisma_migrations"}, + []string{"Drizzle", "Prisma"}, + ) + if !strings.Contains(msg, "re-baseline ORM migrations") { + t.Fatalf("message = %q", msg) + } + if !strings.Contains(msg, "Drizzle, Prisma") { + t.Fatalf("message = %q", msg) + } +} + +func TestCompleteSlackMessageWithoutORMTables(t *testing.T) { + msg := CompleteSlackMessage(nil, nil) + if !strings.Contains(msg, "re-baseline migration history") { + t.Fatalf("message = %q", msg) + } +} diff --git a/internal/import/d1/constraints.go b/internal/import/d1/constraints.go new file mode 100644 index 000000000..bea8c5530 --- /dev/null +++ b/internal/import/d1/constraints.go @@ -0,0 +1,317 @@ +package d1 + +import ( + "regexp" + "strings" + + "github.com/planetscale/cli/internal/postgres" +) + +var ( + referencesClauseRe = regexp.MustCompile(`(?is)^REFERENCES\s+(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(\s*([^)]+)\)\s*(.*)$`) + foreignKeyConstraintRe = 
regexp.MustCompile(`(?is)^FOREIGN\s+KEY\s*\(\s*([^)]+)\)\s*(REFERENCES\s+.+)$`) + primaryKeyConstraintRe = regexp.MustCompile(`(?is)^PRIMARY\s+KEY\s*\(\s*([^)]+)\)\s*$`) + uniqueConstraintRe = regexp.MustCompile(`(?is)^UNIQUE\s*\(\s*([^)]+)\)\s*$`) + createIndexRe = regexp.MustCompile(`(?is)^CREATE\s+(UNIQUE\s+)?INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s+ON\s+(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(\s*([^)]+)\)\s*;?\s*$`) +) + +// IndexSchema holds a parsed CREATE INDEX statement from a dump. +type IndexSchema struct { + Name string + Table string + Unique bool + Columns string + RawDDL string +} + +func isTableConstraint(part string) bool { + upper := strings.ToUpper(strings.TrimSpace(part)) + return strings.HasPrefix(upper, "PRIMARY KEY") || + strings.HasPrefix(upper, "FOREIGN KEY") || + strings.HasPrefix(upper, "UNIQUE(") || + strings.HasPrefix(upper, "UNIQUE (") || + strings.HasPrefix(upper, "CHECK(") || + strings.HasPrefix(upper, "CHECK (") || + strings.HasPrefix(upper, "CONSTRAINT ") +} + +func referencesClause(colDef string) string { + idx := indexOfIgnoreCase(colDef, "REFERENCES") + if idx < 0 { + return "" + } + return strings.TrimSpace(colDef[idx:]) +} + +func convertTableConstraint(clause string) string { + clause = strings.TrimSpace(clause) + clause = strings.TrimSuffix(clause, ",") + if clause == "" { + return "" + } + + upper := strings.ToUpper(clause) + switch { + case strings.HasPrefix(upper, "FOREIGN KEY"): + return convertForeignKeyConstraint(clause) + case strings.HasPrefix(upper, "PRIMARY KEY"): + return convertPrimaryKeyConstraint(clause) + case strings.HasPrefix(upper, "UNIQUE"): + return convertUniqueConstraint(clause) + case strings.HasPrefix(upper, "CHECK"): + return convertCheckConstraint(clause) + default: + return clause + } +} + +func convertCheckConstraint(clause string) string { + clause = strings.TrimSpace(clause) + 
clause = strings.TrimSuffix(clause, ",") + return clause +} + +func convertForeignKeyConstraint(clause string) string { + m := foreignKeyConstraintRe.FindStringSubmatch(clause) + if m == nil { + return clause + } + cols := quoteColumnList(m[1]) + refs := convertReferencesClause(strings.TrimSpace(m[2])) + return "FOREIGN KEY (" + cols + ") " + refs +} + +func convertPrimaryKeyConstraint(clause string) string { + m := primaryKeyConstraintRe.FindStringSubmatch(clause) + if m == nil { + return clause + } + return "PRIMARY KEY (" + quoteColumnList(m[1]) + ")" +} + +func convertUniqueConstraint(clause string) string { + m := uniqueConstraintRe.FindStringSubmatch(clause) + if m == nil { + return clause + } + return "UNIQUE (" + quoteColumnList(m[1]) + ")" +} + +func convertReferencesClause(refs string) string { + m := referencesClauseRe.FindStringSubmatch(refs) + if m == nil { + return refs + } + table := postgres.QuoteIdentifier(firstNonEmpty(m[1], m[2], m[3], m[4])) + refCols := quoteColumnList(m[5]) + tail := strings.TrimSpace(m[6]) + if tail != "" { + return "REFERENCES " + table + " (" + refCols + ") " + tail + } + return "REFERENCES " + table + " (" + refCols + ")" +} + +func quoteColumnList(list string) string { + parts := splitCommaList(list) + quoted := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + quoted = append(quoted, postgres.QuoteIdentifier(strings.Trim(part, "`\"'"))) + } + return strings.Join(quoted, ", ") +} + +func splitCommaList(list string) []string { + var parts []string + var current strings.Builder + depth := 0 + inSingle := false + inDouble := false + + for _, r := range list { + switch r { + case '\'': + if !inDouble { + inSingle = !inSingle + } + current.WriteRune(r) + case '"': + if !inSingle { + inDouble = !inDouble + } + current.WriteRune(r) + case '(': + if !inSingle && !inDouble { + depth++ + } + current.WriteRune(r) + case ')': + if !inSingle && !inDouble { 
+ depth-- + } + current.WriteRune(r) + case ',': + if depth == 0 && !inSingle && !inDouble { + parts = append(parts, current.String()) + current.Reset() + continue + } + current.WriteRune(r) + default: + current.WriteRune(r) + } + } + if current.Len() > 0 { + parts = append(parts, current.String()) + } + return parts +} + +func convertIndexDDL(raw string) string { + m := createIndexRe.FindStringSubmatch(raw) + if m == nil { + return raw + } + unique := strings.TrimSpace(m[1]) != "" + name := postgres.QuoteIdentifier(firstNonEmpty(m[2], m[3], m[4], m[5])) + table := postgres.QuoteIdentifier(firstNonEmpty(m[6], m[7], m[8], m[9])) + cols := quoteColumnList(m[10]) + prefix := "CREATE INDEX IF NOT EXISTS " + if unique { + prefix = "CREATE UNIQUE INDEX IF NOT EXISTS " + } + return prefix + name + " ON " + table + " (" + cols + ");" +} + +func isUUIDColumn(col ColumnSchema, table TableSchema, all []TableSchema, ctx *TypeCoercionContext) bool { + if isExplicitUUIDColumn(col) { + if ctx == nil { + return false + } + return samplesAllowUUID(table.Name, col.Name, ctx) + } + return columnReferencesUUIDKey(col, table, all, ctx) +} + +const maxUUIDFKDepth = 32 + +func columnReferencesUUIDKey(col ColumnSchema, table TableSchema, all []TableSchema, ctx *TypeCoercionContext) bool { + visited := make(map[string]struct{}) + return columnReferencesUUIDKeyVisited(col, table, all, ctx, visited, 0) +} + +func columnReferencesUUIDKeyVisited(col ColumnSchema, table TableSchema, all []TableSchema, ctx *TypeCoercionContext, visited map[string]struct{}, depth int) bool { + if depth >= maxUUIDFKDepth { + return false + } + key := table.Name + "." 
+ col.Name + if _, seen := visited[key]; seen { + return false + } + visited[key] = struct{}{} + + refTable, refCol := columnFKTarget(col, table) + if refTable == "" { + return false + } + ref := tableByName(all, refTable) + if ref == nil { + return false + } + refColSchema := columnByName(*ref, refCol) + if isExplicitUUIDColumn(refColSchema) { + if ctx == nil { + return false + } + return samplesAllowUUID(ref.Name, refColSchema.Name, ctx) + } + return columnReferencesUUIDKeyVisited(refColSchema, *ref, all, ctx, visited, depth+1) +} + +func isExplicitUUIDColumn(col ColumnSchema) bool { + name := strings.ToLower(col.Name) + t := strings.ToUpper(col.Type) + + if !isTextLikeType(t) { + return false + } + + if col.PrimaryKey && (name == "id" || name == "uuid") { + return true + } + if strings.HasSuffix(name, "_uuid") { + return true + } + return false +} + +func columnFKTarget(col ColumnSchema, table TableSchema) (string, string) { + if col.ForeignKey != "" { + return parseReferencesTarget(col.ForeignKey) + } + for _, constraint := range table.Constraints { + cols, refs := parseTableLevelForeignKey(constraint) + for _, name := range cols { + if name == col.Name { + return parseReferencesTarget(refs) + } + } + } + return "", "" +} + +func parseTableLevelForeignKey(constraint string) ([]string, string) { + m := foreignKeyConstraintRe.FindStringSubmatch(constraint) + if m == nil { + return nil, "" + } + cols := make([]string, 0) + for _, part := range splitCommaList(m[1]) { + part = strings.Trim(strings.TrimSpace(part), "`\"'") + if part != "" { + cols = append(cols, part) + } + } + return cols, strings.TrimSpace(m[2]) +} + +func parseReferencesTarget(refs string) (string, string) { + m := referencesClauseRe.FindStringSubmatch(strings.TrimSpace(refs)) + if m == nil { + return "", "" + } + table := firstNonEmpty(m[1], m[2], m[3], m[4]) + refCols := splitCommaList(m[5]) + refCol := "" + if len(refCols) > 0 { + refCol = strings.Trim(strings.TrimSpace(refCols[0]), "`\"'") + } 
+ return table, refCol +} + +func tableByName(all []TableSchema, name string) *TableSchema { + lower := strings.ToLower(name) + for i := range all { + if strings.ToLower(all[i].Name) == lower { + return &all[i] + } + } + return nil +} + +func columnByName(table TableSchema, name string) ColumnSchema { + lower := strings.ToLower(name) + for _, col := range table.Columns { + if strings.ToLower(col.Name) == lower { + return col + } + } + return ColumnSchema{} +} + +func isTextLikeType(t string) bool { + return t == "" || strings.Contains(t, "CHAR") || strings.Contains(t, "CLOB") || strings.Contains(t, "TEXT") +} diff --git a/internal/import/d1/constraints_test.go b/internal/import/d1/constraints_test.go new file mode 100644 index 000000000..35d4fc9c6 --- /dev/null +++ b/internal/import/d1/constraints_test.go @@ -0,0 +1,35 @@ +package d1 + +import "testing" + +func TestParseTableLevelForeignKey(t *testing.T) { + cols, refs := parseTableLevelForeignKey(`FOREIGN KEY (entity_id) REFERENCES external_entities(id)`) + if len(cols) != 1 || cols[0] != "entity_id" { + t.Fatalf("unexpected columns: %#v", cols) + } + refTable, refCol := parseReferencesTarget(refs) + if refTable != "external_entities" || refCol != "id" { + t.Fatalf("unexpected ref target: %s.%s", refTable, refCol) + } +} + +func TestColumnFKTargetUsesTableConstraint(t *testing.T) { + table := TableSchema{ + Name: "entity_links", + Columns: []ColumnSchema{ + {Name: "entity_id", Type: "TEXT", NotNull: true}, + {Name: "post_id", Type: "INTEGER", NotNull: true}, + }, + Constraints: []string{ + `PRIMARY KEY (entity_id, post_id)`, + `FOREIGN KEY (entity_id) REFERENCES external_entities(id)`, + `FOREIGN KEY (post_id) REFERENCES posts(id)`, + }, + } + col := table.Columns[0] + + refTable, refCol := columnFKTarget(col, table) + if refTable != "external_entities" || refCol != "id" { + t.Fatalf("got %s.%s", refTable, refCol) + } +} diff --git a/internal/import/d1/convert.go b/internal/import/d1/convert.go new file mode 
100644 index 000000000..00f715b00 --- /dev/null +++ b/internal/import/d1/convert.go @@ -0,0 +1,334 @@ +package d1 + +import ( + "fmt" + "os" + "regexp" + "strconv" + "strings" + + "github.com/planetscale/cli/internal/postgres" +) + +var ( + sqliteTypeCleanup = regexp.MustCompile(`(?i)\s+PRIMARY\s+KEY\s+AUTOINCREMENT`) +) + +// SchemaParts holds table DDL and secondary index DDL separately so imports can +// load data before building indexes (much faster than maintaining indexes per row). +type SchemaParts struct { + Tables string + Indexes string +} + +// ConvertSchemaParts converts SQLite DDL into Postgres table and index SQL. +func ConvertSchemaParts(inputPath string) (SchemaParts, int, error) { + tables, err := ParseDump(inputPath) + if err != nil { + return SchemaParts{}, 0, err + } + + coerceCtx, err := BuildTypeCoercionContext(inputPath, tables) + if err != nil { + return SchemaParts{}, 0, err + } + + indexes, err := ParseIndexes(inputPath) + if err != nil { + return SchemaParts{}, 0, err + } + + var tableBuf strings.Builder + tableBuf.WriteString("-- Generated by pscale import d1 convert-schema (tables)\n") + fmt.Fprintf(&tableBuf, "-- Source: %s\n\n", inputPath) + + converted := 0 + tableByName := make(map[string]TableSchema, len(tables)) + for _, table := range tables { + tableByName[table.Name] = table + } + for _, name := range topologicalLoadOrder(tables) { + table, ok := tableByName[name] + if !ok { + continue + } + if IsORMMetadataTable(table.Name) { + continue + } + tableBuf.WriteString(convertTableDDL(table, tables, coerceCtx)) + tableBuf.WriteString("\n\n") + converted++ + } + + var indexBuf strings.Builder + if len(indexes) > 0 { + indexBuf.WriteString("-- Generated by pscale import d1 convert-schema (indexes)\n") + fmt.Fprintf(&indexBuf, "-- Source: %s\n\n", inputPath) + indexBuf.WriteString("-- Indexes\n") + for _, idx := range indexes { + if IsORMMetadataTable(idx.Table) { + continue + } + indexBuf.WriteString(convertIndexDDL(idx.RawDDL)) + 
indexBuf.WriteString("\n") + } + indexBuf.WriteString("\n") + } + + return SchemaParts{ + Tables: tableBuf.String(), + Indexes: indexBuf.String(), + }, converted, nil +} + +// ConvertSchema converts SQLite CREATE TABLE statements to PostgreSQL DDL. +func ConvertSchema(inputPath, outputPath string) (int, error) { + parts, converted, err := ConvertSchemaParts(inputPath) + if err != nil { + return 0, err + } + + var b strings.Builder + b.WriteString(parts.Tables) + if parts.Indexes != "" { + b.WriteString(parts.Indexes) + } + + if err := os.WriteFile(outputPath, []byte(b.String()), 0o600); err != nil { + return 0, fmt.Errorf("write schema: %w", err) + } + + return converted, nil +} + +func convertTableDDL(table TableSchema, all []TableSchema, ctx *TypeCoercionContext) string { + var b strings.Builder + fmt.Fprintf(&b, "CREATE TABLE IF NOT EXISTS %s (\n", postgres.QuoteIdentifier(table.Name)) + + var lines []string + for _, col := range table.Columns { + lines = append(lines, " "+convertColumn(col, table, all, ctx)) + } + for _, constraint := range table.Constraints { + if converted := convertTableConstraint(constraint); converted != "" { + lines = append(lines, " "+converted) + } + } + b.WriteString(strings.Join(lines, ",\n")) + b.WriteString("\n);\n") + + return b.String() +} + +func convertColumn(col ColumnSchema, table TableSchema, all []TableSchema, ctx *TypeCoercionContext) string { + pgType := sqliteTypeToPostgres(col, table, all, ctx) + + var parts []string + parts = append(parts, postgres.QuoteIdentifier(col.Name), pgType) + + if col.AutoIncrement { + parts = append(parts, "GENERATED BY DEFAULT AS IDENTITY") + if col.PrimaryKey { + parts = append(parts, "PRIMARY KEY") + } + } else if col.PrimaryKey { + parts = append(parts, "PRIMARY KEY") + } + + if col.NotNull && !col.AutoIncrement { + parts = append(parts, "NOT NULL") + } + + if col.Unique { + parts = append(parts, "UNIQUE") + } + + if col.DefaultValue != "" && !col.AutoIncrement { + parts = append(parts, 
"DEFAULT", convertDefault(col.DefaultValue, pgType)) + } + + if col.ForeignKey != "" { + parts = append(parts, convertReferencesClause(col.ForeignKey)) + } + + return strings.Join(parts, " ") +} + +func sqliteTypeToPostgres(col ColumnSchema, table TableSchema, all []TableSchema, ctx *TypeCoercionContext) string { + if isUUIDColumn(col, table, all, ctx) { + return "UUID" + } + + t := strings.ToUpper(col.Type) + + if col.AutoIncrement { + if strings.Contains(t, "BIG") { + return "BIGINT" + } + return "INTEGER" + } + + switch { + case t == "" || t == "NUMERIC": + return "TEXT" + case strings.Contains(t, "INT"): + if isBooleanLikeColumn(col, table, ctx) { + return "BOOLEAN" + } + return "BIGINT" + case strings.Contains(t, "CHAR") || strings.Contains(t, "CLOB") || strings.Contains(t, "TEXT"): + if isJSONText(col) && ctx != nil && samplesAllowJSON(table.Name, col.Name, ctx) { + return "JSONB" + } + if isTimestampText(col) && ctx != nil && samplesAllowTimestamp(table.Name, col.Name, ctx) { + return "TIMESTAMPTZ" + } + return "TEXT" + case strings.Contains(t, "BLOB"): + return "BYTEA" + case strings.Contains(t, "REAL") || strings.Contains(t, "FLOA") || strings.Contains(t, "DOUB"): + return "DOUBLE PRECISION" + case strings.Contains(t, "BOOL"): + return "BOOLEAN" + default: + return "TEXT" + } +} + +func convertDefault(def, pgType string) string { + def = strings.TrimSpace(def) + upper := strings.ToUpper(def) + if upper == "NULL" { + return "NULL" + } + if mapped := mapSQLiteDefaultFunction(def, pgType); mapped != "" { + return mapped + } + if pgType == "BOOLEAN" && (def == "0" || def == "1") { + if def == "1" { + return "TRUE" + } + return "FALSE" + } + if pgType == "UUID" { + def = strings.Trim(def, "'\"") + return "'" + def + "'" + } + if strings.HasPrefix(def, "'") || strings.HasPrefix(def, `"`) { + return def + } + if pgType == "TEXT" || pgType == "TIMESTAMPTZ" { + return "'" + strings.Trim(def, "'\"") + "'" + } + return def +} + +func mapSQLiteDefaultFunction(def, 
pgType string) string { + trimmed := strings.TrimSpace(def) + upper := strings.ToUpper(trimmed) + switch upper { + case "CURRENT_TIMESTAMP", "(CURRENT_TIMESTAMP)": + return "CURRENT_TIMESTAMP" + case "CURRENT_DATE", "(CURRENT_DATE)": + return "CURRENT_DATE" + case "CURRENT_TIME", "(CURRENT_TIME)": + return "CURRENT_TIME" + } + if strings.HasPrefix(upper, "DATETIME(") || strings.HasPrefix(upper, "(DATETIME(") { + return "now()" + } + if arg := sqliteFunctionArg(trimmed, "UNIXEPOCH"); arg != "" || strings.HasSuffix(upper, "UNIXEPOCH()") { + if mapped := mapUnixEpochDefault(arg, pgType); mapped != "" { + return mapped + } + } + return "" +} + +func mapUnixEpochDefault(arg, pgType string) string { + modifier := strings.ToUpper(strings.Trim(strings.TrimSpace(arg), `'"`)) + switch modifier { + case "", "NOW": + return unixEpochNowDefault(pgType) + case "SUBSEC": + return unixEpochSubsecDefault(pgType) + } + if pgType == "TIMESTAMPTZ" && unixEpochArgLooksNumeric(arg) { + return "to_timestamp(" + strings.TrimSpace(arg) + ")" + } + if (pgType == "BIGINT" || pgType == "INTEGER" || pgType == "DOUBLE PRECISION") && unixEpochArgLooksNumeric(arg) { + return strings.TrimSpace(arg) + } + return "" +} + +func unixEpochNowDefault(pgType string) string { + switch pgType { + case "TIMESTAMPTZ": + return "now()" + case "BIGINT", "INTEGER": + return "extract(epoch from now())::bigint" + case "DOUBLE PRECISION": + return "extract(epoch from now())" + default: + return "now()" + } +} + +func unixEpochSubsecDefault(pgType string) string { + switch pgType { + case "TIMESTAMPTZ": + return "clock_timestamp()" + case "BIGINT", "INTEGER": + return "extract(epoch from clock_timestamp())::bigint" + case "DOUBLE PRECISION": + return "extract(epoch from clock_timestamp())" + default: + return "clock_timestamp()" + } +} + +func unixEpochArgLooksNumeric(arg string) bool { + arg = strings.TrimSpace(arg) + if arg == "" { + return false + } + unquoted := strings.Trim(arg, `'"`) + if _, err := 
strconv.ParseFloat(unquoted, 64); err == nil { + return true + } + return !strings.HasPrefix(arg, "'") && !strings.HasPrefix(arg, `"`) +} + +func sqliteFunctionArg(s, fn string) string { + fnUpper := strings.ToUpper(fn) + idx := strings.Index(strings.ToUpper(s), fnUpper+"(") + if idx < 0 { + return "" + } + start := idx + len(fnUpper) + 1 + depth := 1 + for i := start; i < len(s); i++ { + switch s[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + return strings.TrimSpace(s[start:i]) + } + } + } + return "" +} + +// ConvertCreateStatement converts a raw SQLite CREATE TABLE line to Postgres (for tests). +func ConvertCreateStatement(sqliteDDL string) string { + ddl := sqliteTypeCleanup.ReplaceAllString(sqliteDDL, "") + ddl = regexp.MustCompile(`(?i)\bAUTOINCREMENT\b`).ReplaceAllString(ddl, "") + ddl = regexp.MustCompile(`(?i)\bINTEGER\b`).ReplaceAllStringFunc(ddl, func(s string) string { + return "BIGINT" + }) + ddl = regexp.MustCompile(`(?i)\bREAL\b`).ReplaceAllString(ddl, "DOUBLE PRECISION") + return ddl +} diff --git a/internal/import/d1/convert_test.go b/internal/import/d1/convert_test.go new file mode 100644 index 000000000..4d4a11a8b --- /dev/null +++ b/internal/import/d1/convert_test.go @@ -0,0 +1,56 @@ +package d1 + +import "testing" + +func TestLooksLikeTimestampColumnName(t *testing.T) { + cases := map[string]bool{ + "created_at": true, + "updated_at": true, + "event_date": true, + "date_of_birth": true, + "date": true, + "timestamp_raw": true, + "candidate": false, + "mandate": false, + "metadata": false, + } + for name, want := range cases { + if got := looksLikeTimestampColumnName(name); got != want { + t.Fatalf("looksLikeTimestampColumnName(%q) = %v, want %v", name, got, want) + } + } +} + +func TestIsTimestampTextIgnoresFalsePositiveNames(t *testing.T) { + for _, name := range []string{"candidate", "mandate"} { + col := ColumnSchema{Name: name, Type: "TEXT"} + if isTimestampText(col) { + t.Fatalf("isTimestampText(%q) = true, want 
false", name) + } + } +} + +func TestMapSQLiteDefaultFunctionUnixEpoch(t *testing.T) { + cases := map[string]struct { + def string + pgType string + want string + }{ + "unixepoch('now') timestamptz": {"unixepoch('now')", "TIMESTAMPTZ", "now()"}, + "UNIXEPOCH('now') timestamptz": {"UNIXEPOCH('now')", "TIMESTAMPTZ", "now()"}, + "UnixEpoch('now') timestamptz": {"UnixEpoch('now')", "TIMESTAMPTZ", "now()"}, + "(UNIXEPOCH('now')) timestamptz": {"(UNIXEPOCH('now'))", "TIMESTAMPTZ", "now()"}, + "UNIXEPOCH('subsec') timestamptz": {"UNIXEPOCH('subsec')", "TIMESTAMPTZ", "clock_timestamp()"}, + "unixepoch() timestamptz": {"unixepoch()", "TIMESTAMPTZ", "now()"}, + "unixepoch('now') bigint": {"unixepoch('now')", "BIGINT", "extract(epoch from now())::bigint"}, + "unixepoch numeric timestamptz": {"UNIXEPOCH(1700000000)", "TIMESTAMPTZ", "to_timestamp(1700000000)"}, + "CURRENT_TIMESTAMP": {"CURRENT_TIMESTAMP", "TIMESTAMPTZ", "CURRENT_TIMESTAMP"}, + "datetime('now')": {"datetime('now')", "TIMESTAMPTZ", "now()"}, + } + for name, tc := range cases { + got := mapSQLiteDefaultFunction(tc.def, tc.pgType) + if got != tc.want { + t.Fatalf("%s: mapSQLiteDefaultFunction(%q, %q) = %q, want %q", name, tc.def, tc.pgType, got, tc.want) + } + } +} diff --git a/internal/import/d1/doctor.go b/internal/import/d1/doctor.go new file mode 100644 index 000000000..f184e4af9 --- /dev/null +++ b/internal/import/d1/doctor.go @@ -0,0 +1,224 @@ +package d1 + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/planetscale/cli/internal/postgres" + execabs "golang.org/x/sys/execabs" +) + +const ( + checkOK = "ok" + checkWarn = "warn" + checkFail = "fail" + checkSkip = "skip" +) + +// Doctor runs prerequisite checks for D1 migration. 
+func Doctor(ctx context.Context) (*DoctorResult, error) { + checks := []DoctorCheck{ + checkWrangler(ctx), + checkPgloader(ctx), + checkPsql(), + checkSQLite3(ctx), + checkCloudflareEnv(), + } + + result := &DoctorResult{Checks: checks, Ready: true} + for _, c := range checks { + if c.Status == checkFail { + result.Ready = false + } + } + return result, nil +} + +func checkWrangler(ctx context.Context) DoctorCheck { + for _, cmd := range []string{"wrangler", "npx"} { + path, err := execabs.LookPath(cmd) + if err != nil { + continue + } + if cmd == "npx" { + c := execabs.CommandContext(ctx, path, "wrangler", "--version") + out, err := c.CombinedOutput() + if err == nil { + return DoctorCheck{ + Name: "wrangler", + Status: checkOK, + Version: strings.TrimSpace(string(out)), + } + } + continue + } + c := execabs.CommandContext(ctx, path, "--version") + out, err := c.CombinedOutput() + if err == nil { + return DoctorCheck{ + Name: "wrangler", + Status: checkOK, + Version: strings.TrimSpace(string(out)), + } + } + } + + return DoctorCheck{ + Name: "wrangler", + Status: checkWarn, + Message: "wrangler not found", + Remediation: wranglerMissingRemediation, + } +} + +func checkPgloader(ctx context.Context) DoctorCheck { + path, err := execabs.LookPath("pgloader") + if err != nil { + return DoctorCheck{ + Name: "pgloader", + Status: checkFail, + Message: "pgloader not found", + Remediation: pgloaderInstallRemediation, + } + } + c := execabs.CommandContext(ctx, path, "--version") + out, err := c.CombinedOutput() + if err != nil { + return DoctorCheck{ + Name: "pgloader", + Status: checkFail, + Message: "pgloader found but --version failed", + Remediation: "Reinstall pgloader", + } + } + return DoctorCheck{ + Name: "pgloader", + Status: checkOK, + Version: strings.TrimSpace(string(out)), + } +} + +func checkPsql() DoctorCheck { + major, minor, err := postgres.CheckPsqlVersion(10) + if err != nil { + return DoctorCheck{ + Name: "psql", + Status: checkFail, + Message: 
err.Error(), + Remediation: "Install PostgreSQL client tools: brew install postgresql@18", + } + } + return DoctorCheck{ + Name: "psql", + Status: checkOK, + Version: fmt.Sprintf("%d.%d", major, minor), + } +} + +func checkSQLite3(ctx context.Context) DoctorCheck { + path, err := execabs.LookPath("sqlite3") + if err != nil { + return DoctorCheck{ + Name: "sqlite3", + Status: checkSkip, + Message: "sqlite3 CLI not found", + Remediation: "Optional: install sqlite3 for verify and pgloader prep (brew install sqlite)", + } + } + c := execabs.CommandContext(ctx, path, "--version") + out, err := c.CombinedOutput() + if err != nil { + return DoctorCheck{ + Name: "sqlite3", + Status: checkSkip, + } + } + return DoctorCheck{ + Name: "sqlite3", + Status: checkOK, + Version: strings.TrimSpace(string(out)), + } +} + +func checkCloudflareEnv() DoctorCheck { + token := os.Getenv("CLOUDFLARE_API_TOKEN") + account := os.Getenv("CLOUDFLARE_ACCOUNT_ID") + if token != "" && account != "" { + return DoctorCheck{ + Name: "cloudflare_auth", + Status: checkOK, + } + } + return DoctorCheck{ + Name: "cloudflare_auth", + Status: checkWarn, + Message: "CLOUDFLARE_API_TOKEN and/or CLOUDFLARE_ACCOUNT_ID not set", + Remediation: "Set Cloudflare env vars for remote export, or pass --input with an existing dump", + } +} + +// DoctorReadinessError summarizes failed prerequisite checks for doctor/start. 
+func DoctorReadinessError(result *DoctorResult) error { + if result == nil || result.Ready { + return nil + } + + var parts []string + var remediations []string + for _, c := range result.Checks { + if c.Status != checkFail { + continue + } + msg := c.Name + if c.Message != "" { + msg += ": " + c.Message + } + parts = append(parts, msg) + if c.Remediation != "" { + remediations = append(remediations, c.Remediation) + } + } + + message := "prerequisites not met" + if len(parts) > 0 { + message = strings.Join(parts, "; ") + } + remediation := strings.Join(remediations, "; ") + if remediation == "" { + remediation = "Run `pscale import d1 doctor` and fix failed checks" + } + return newMigrationError(ErrCodePrereqFailed, message, remediation) +} + +// DoctorNextSteps suggests next actions after doctor. +func DoctorNextSteps(result *DoctorResult) []NextStep { + if !result.Ready { + return []NextStep{ + {Command: "pscale import d1 doctor", Reason: "Fix failed checks and re-run doctor"}, + } + } + return []NextStep{ + {Command: "wrangler d1 export --remote --output ./d1-export.sql", Reason: "Export D1 database with wrangler"}, + {Command: "pscale import d1 lint --input ./d1-export.sql", Reason: "Lint the export before import"}, + } +} + +// FindPgloader returns pgloader path. +func FindPgloader() (string, error) { + path, err := execabs.LookPath("pgloader") + if err != nil { + return "", errMissingTool("pgloader", pgloaderInstallRemediation) + } + return path, nil +} + +// FindSQLite3 returns sqlite3 path. 
+func FindSQLite3() (string, error) { + path, err := execabs.LookPath("sqlite3") + if err != nil { + return "", errMissingTool("sqlite3", "Install with: brew install sqlite") + } + return path, nil +} diff --git a/internal/import/d1/doctor_test.go b/internal/import/d1/doctor_test.go new file mode 100644 index 000000000..1c7a2ec2d --- /dev/null +++ b/internal/import/d1/doctor_test.go @@ -0,0 +1,62 @@ +package d1 + +import ( + "context" + "testing" +) + +func TestDoctor_RequiresPgloader(t *testing.T) { + if _, err := FindPgloader(); err == nil { + t.Skip("pgloader installed") + } + + result, err := Doctor(context.Background()) + if err != nil { + t.Fatalf("Doctor: %v", err) + } + if result.Ready { + t.Fatal("expected doctor not ready without pgloader") + } + + var pgloaderCheck DoctorCheck + for _, c := range result.Checks { + if c.Name == "pgloader" { + pgloaderCheck = c + break + } + } + if pgloaderCheck.Status != checkFail { + t.Fatalf("pgloader check status = %q, want %q", pgloaderCheck.Status, checkFail) + } + + if err := DoctorReadinessError(result); err == nil { + t.Fatal("expected readiness error") + } else { + requireMigrationErr(t, err, ErrCodePrereqFailed) + } +} + +func TestImport_RequiresPgloader(t *testing.T) { + if _, err := FindPgloader(); err == nil { + t.Skip("pgloader installed") + } + + result, err := Import(context.Background(), nil, nil, ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + }, nil) + if err == nil { + t.Fatal("expected missing pgloader error") + } + requireMigrationErr(t, err, ErrCodeMissingTool) + if result == nil { + t.Fatal("expected import result on failure") + } + if result.MigrationID == "" { + t.Fatal("expected migration_id in failure result") + } + if result.Lint == nil || result.Plan == nil { + t.Fatal("expected lint and plan in failure result") + } +} diff --git a/internal/import/d1/errors.go b/internal/import/d1/errors.go new file mode 100644 index 000000000..942d5fa74 --- /dev/null +++ 
b/internal/import/d1/errors.go @@ -0,0 +1,95 @@ +package d1 + +import ( + "errors" + "fmt" + "strings" +) + +// ErrCode constants for structured errors. +const ( + ErrCodeVirtualTable = "VIRTUAL_TABLE" + ErrCodeMissingInput = "MISSING_INPUT" + ErrCodeMissingTool = "MISSING_TOOL" + ErrCodeInvalidInput = "INVALID_INPUT" + ErrCodeImportFailed = "IMPORT_FAILED" + ErrCodeVerifyFailed = "VERIFY_FAILED" + ErrCodeNotFound = "NOT_FOUND" + ErrCodePrereqFailed = "PREREQ_FAILED" + ErrCodeLintBlocked = "LINT_BLOCKED" + ErrCodeDestinationConflict = "DESTINATION_CONFLICT" + ErrCodeStatePersistFailed = "STATE_PERSIST_FAILED" +) + +const ( + wranglerMissingRemediation = "Install wrangler, use npx wrangler d1 export, or pass --input if you already have a dump." + pgloaderInstallRemediation = "Install pgloader (brew install pgloader on macOS; see https://pgloader.readthedocs.io/en/latest/install.html for other platforms)" + lintBlockedRemediation = "Fix lint errors or run `pscale import d1 lint` for details; use `import d1 start --dry-run` for a read-only preview" +) + +type MigrationError struct { + Info ErrorInfo +} + +func (e *MigrationError) Error() string { + return e.Info.Message +} + +func migrationErr(err error) (*MigrationError, bool) { + var me *MigrationError + if errors.As(err, &me) { + return me, true + } + return nil, false +} + +func newMigrationError(code, message, remediation string) *MigrationError { + return &MigrationError{ + Info: ErrorInfo{ + Code: code, + Message: message, + Remediation: remediation, + }, + } +} + +func lintBlockedReason(errorCount int) string { + return fmt.Sprintf("lint reported %d error(s); fix or use import d1 lint for details", errorCount) +} + +func errMissingInput(path string) error { + return newMigrationError( + ErrCodeMissingInput, + fmt.Sprintf("input file not found: %s", path), + "Export with wrangler (wrangler d1 export --remote --output ./dump.sql) or pass an existing dump with --input", + ) +} + +func errMissingTool(name, 
remediation string) error { + return newMigrationError( + ErrCodeMissingTool, + fmt.Sprintf("required tool not found: %s", name), + remediation, + ) +} + +func errExistingImportTables(tables []string) error { + return newMigrationError( + ErrCodeDestinationConflict, + fmt.Sprintf("destination already has tables from this import: %s", strings.Join(tables, ", ")), + "Use a new branch, drop the conflicting tables, or choose a database without overlapping table names before importing", + ) +} + +func errStatePersist(operation string, err error) error { + return newMigrationError( + ErrCodeStatePersistFailed, + fmt.Sprintf("%s succeeded but local migration state could not be saved: %v", operation, err), + "Postgres may already reflect the finished step; re-run status or verify before continuing", + ) +} + +// ErrLintBlocked returns a structured error when lint errors block import. +func ErrLintBlocked(reason string) error { + return newMigrationError(ErrCodeLintBlocked, reason, lintBlockedRemediation) +} diff --git a/internal/import/d1/errors_test.go b/internal/import/d1/errors_test.go new file mode 100644 index 000000000..b2c9a149f --- /dev/null +++ b/internal/import/d1/errors_test.go @@ -0,0 +1,14 @@ +package d1 + +import "testing" + +func requireMigrationErr(t *testing.T, err error, code string) { + t.Helper() + me, ok := migrationErr(err) + if !ok { + t.Fatalf("expected MigrationError, got %T: %v", err, err) + } + if me.Info.Code != code { + t.Fatalf("code = %q, want %q", me.Info.Code, code) + } +} diff --git a/internal/import/d1/identifiers.go b/internal/import/d1/identifiers.go new file mode 100644 index 000000000..cd7c25b3a --- /dev/null +++ b/internal/import/d1/identifiers.go @@ -0,0 +1,60 @@ +package d1 + +import ( + "fmt" + "unicode" +) + +const postgresMaxIdentifierBytes = 63 + +func lintIdentifiers(table TableSchema) []Issue { + var issues []Issue + issues = append(issues, lintIdentifier(table.Name, table.Name, "")...) 
+ for _, col := range table.Columns { + issues = append(issues, lintIdentifier(table.Name, col.Name, col.Name)...) + } + return issues +} + +func lintIdentifier(table, name, column string) []Issue { + var issues []Issue + if len(name) > postgresMaxIdentifierBytes { + target := "table" + if column != "" { + target = "column" + } + issues = append(issues, Issue{ + Code: "IDENTIFIER_TOO_LONG", + Severity: SeverityError, + Table: table, + Column: column, + Message: fmt.Sprintf("%s name %q exceeds PostgreSQL 63-byte identifier limit (%d bytes)", target, name, len(name)), + Remediation: "Rename the " + target + " in SQLite before export, or use quoted identifiers that fit within 63 bytes in PostgreSQL", + }) + } + if hasMixedCaseIdentifier(name) { + issues = append(issues, Issue{ + Code: "MIXED_CASE_IDENTIFIER", + Severity: SeverityWarning, + Table: table, + Column: column, + Message: fmt.Sprintf("identifier %q contains uppercase letters", name), + Remediation: "PostgreSQL folds unquoted identifiers to lowercase; prefer snake_case in D1 exports to avoid case mismatches during import", + }) + } + return issues +} + +func hasMixedCaseIdentifier(name string) bool { + hasUpper := false + hasLower := false + for _, r := range name { + if unicode.IsUpper(r) { + hasUpper = true + } + if unicode.IsLower(r) { + hasLower = true + } + } + return hasUpper && hasLower +} diff --git a/internal/import/d1/identifiers_test.go b/internal/import/d1/identifiers_test.go new file mode 100644 index 000000000..22a27b19b --- /dev/null +++ b/internal/import/d1/identifiers_test.go @@ -0,0 +1,49 @@ +package d1 + +import "testing" + +func TestLintIdentifiers(t *testing.T) { + longName := stringsRepeat("a", 64) + table := TableSchema{ + Name: longName, + Columns: []ColumnSchema{ + {Name: "ok_col", Type: "TEXT"}, + {Name: stringsRepeat("b", 64), Type: "TEXT"}, + {Name: "UserId", Type: "INTEGER"}, + }, + } + + issues := lintIdentifiers(table) + if len(issues) != 3 { + t.Fatalf("expected 3 issues, got 
%d: %#v", len(issues), issues) + } + if issues[0].Code != "IDENTIFIER_TOO_LONG" || issues[0].Severity != SeverityError { + t.Fatalf("table issue = %#v", issues[0]) + } + if issues[1].Code != "IDENTIFIER_TOO_LONG" || issues[1].Column == "" { + t.Fatalf("column issue = %#v", issues[1]) + } + if issues[2].Code != "MIXED_CASE_IDENTIFIER" || issues[2].Column != "UserId" { + t.Fatalf("mixed case issue = %#v", issues[2]) + } +} + +func TestHasMixedCaseIdentifier(t *testing.T) { + if hasMixedCaseIdentifier("user_id") { + t.Fatal("snake_case should not flag") + } + if hasMixedCaseIdentifier("USER_ID") { + t.Fatal("all caps should not flag") + } + if !hasMixedCaseIdentifier("UserId") { + t.Fatal("mixed case should flag") + } +} + +func stringsRepeat(s string, n int) string { + out := make([]byte, n) + for i := range out { + out[i] = s[0] + } + return string(out) +} diff --git a/internal/import/d1/import.go b/internal/import/d1/import.go new file mode 100644 index 000000000..32ddb053b --- /dev/null +++ b/internal/import/d1/import.go @@ -0,0 +1,646 @@ +package d1 + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "slices" + "strings" + "time" + + ps "github.com/planetscale/planetscale-go/planetscale" + + "github.com/planetscale/cli/internal/postgres" + "github.com/planetscale/cli/internal/roleutil" + execabs "golang.org/x/sys/execabs" +) + +// ImportOptions configures D1 import into PlanetScale Postgres. +type ImportOptions struct { + Org string + Database string + Branch string + InputPath string + Method string + MigrationID string + DBName string + DryRun bool + DestURI string // optional override for testing + NotifyAPI NotifyAPIConfig + OnProgress ImportProgressFunc + // PgloaderVerbose emits full pgloader reports to stderr (defaults to false). + PgloaderVerbose bool + notifyBase importNotificationPayload +} + +// ImportClient abstracts PlanetScale API access for import. 
+type ImportClient interface { + GetDatabase(ctx context.Context, org, database string) (*ps.Database, error) +} + +// DefaultImportClient wraps planetscale client. +type DefaultImportClient struct { + Client *ps.Client +} + +func (c *DefaultImportClient) GetDatabase(ctx context.Context, org, database string) (*ps.Database, error) { + return c.Client.Databases.Get(ctx, &ps.GetDatabaseRequest{ + Organization: org, + Database: database, + }) +} + +// Import loads a D1 SQLite dump into PlanetScale Postgres. +// Pass prepared when the caller already ran PrepareImport (e.g. human confirm flow). +func Import(ctx context.Context, psClient *ps.Client, client ImportClient, opts ImportOptions, prepared *ImportPrepareResult) (result *ImportResult, err error) { + if prepared == nil { + prepared, err = PrepareImport(opts) + if err != nil { + return nil, err + } + } + + opts.MigrationID = prepared.MigrationID + opts.Method = prepared.Method + if opts.InputPath == "" && prepared.Plan != nil { + opts.InputPath = prepared.Plan.InputPath + } else if opts.InputPath != "" { + if normalized, err := NormalizeInputPath(opts.InputPath); err != nil { + return nil, err + } else { + opts.InputPath = normalized + } + } + + result = importResultFromPrepare(prepared, opts.DryRun) + + if !prepared.CanProceed { + return result, ErrLintBlocked(prepared.BlockedReason) + } + + if opts.DryRun { + return result, nil + } + + if _, err := FindPgloader(); err != nil { + return result, err + } + + importStarted := false + importDataLoaded := false + var importStart time.Time + defer func() { + if err != nil && importStarted && !importDataLoaded { + _ = saveImportMigrationState(opts, PhaseFailed, "") + if !opts.DryRun { + payload := notifyPayloadFromImport(opts, result) + payload.DurationMs = time.Since(importStart).Milliseconds() + notifyImportFailure(opts.NotifyAPI, opts.Org, opts.Database, opts.Branch, opts.MigrationID, payload, err, nil) + } + } + }() + + importStart = time.Now() + timings := 
&ImportTimings{} + importStarted = true + opts.notifyBase = notifyPayloadFromImport(opts, result) + NotifyImportEventSync(opts.NotifyAPI, opts.Org, opts.Database, opts.Branch, opts.MigrationID, NotifyEventStarting, opts.notifyBase) + + opts.reportProgress(ImportProgress{Stage: ImportStageConnecting}) + + db, err := client.GetDatabase(ctx, opts.Org, opts.Database) + if err != nil { + return result, fmt.Errorf("get database: %w", err) + } + if db.Kind != "postgresql" { + return result, newMigrationError( + ErrCodeInvalidInput, + fmt.Sprintf("database %s is not PostgreSQL", opts.Database), + "Create a PostgreSQL database branch for D1 migration", + ) + } + + sqlitePath := DefaultSQLitePath(opts.InputPath) + if state, stateErr := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID); stateErr == nil { + if err := validateInputPathAgainstState(opts.InputPath, state.InputPath); err != nil { + return result, err + } + if state.SQLitePath != "" { + sqlitePath = state.SQLitePath + } + } + + if shouldPreserveImportProgress(ctx, opts, "") { + if err := saveImportMigrationState(opts, PhaseImporting, ""); err != nil { + return result, err + } + } else if err := resetImportProgress(opts, PhaseImporting, ""); err != nil { + return result, err + } + + sqliteStart := time.Now() + opts.reportProgress(ImportProgress{Stage: ImportStageSQLiteStaging}) + if err := EnsureSQLiteFromDump(ctx, opts.InputPath, sqlitePath); err != nil { + remediation := "Ensure the dump is valid and the host has enough memory and disk for SQLite staging" + if errors.Is(err, context.Canceled) { + remediation = "Import was interrupted; re-run start to resume or start fresh" + } + return result, newMigrationError(ErrCodeImportFailed, err.Error(), remediation) + } + timings.SQLiteStagingMs = time.Since(sqliteStart).Milliseconds() + + destURI, cleanup, err := ResolveDestURI(ctx, psClient, opts) + if err != nil { + return result, err + } + if cleanup != nil { + defer cleanup() + } + + currentUser, err 
:= usernameFromDestURI(destURI) + if err != nil { + return result, err + } + if err := cleanupStaleImportRoles(ctx, psClient, opts, currentUser); err != nil { + return result, err + } + + switch opts.Method { + case MethodPgloader: + if err := importWithPgloader(ctx, opts, destURI, sqlitePath, timings); err != nil { + return result, err + } + case MethodPsql: + if err := importSmall(ctx, opts, destURI, sqlitePath); err != nil { + return result, err + } + default: + return result, newMigrationError(ErrCodeInvalidInput, "unknown import method: "+opts.Method, "Use pgloader (large dumps) or psql (small dumps; data loaded via pgloader)") + } + importDataLoaded = true + + tables, err := ParseDump(opts.InputPath) + if err == nil { + for _, table := range tables { + if !IsORMMetadataTable(table.Name) { + result.TablesLoaded++ + } + } + } + + timings.TotalMs = time.Since(importStart).Milliseconds() + result.Timings = timings + + loadedTables, _ := PgloaderLoadTables(opts.InputPath) + state := &MigrationState{ + MigrationID: opts.MigrationID, + Org: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + InputPath: opts.InputPath, + SQLitePath: sqlitePath, + DBName: opts.DBName, + Method: opts.Method, + Phase: PhaseImported, + LoadedTables: loadedTables, + } + if !opts.DryRun { + if err := SaveState(state); err != nil { + return result, errStatePersist("import", err) + } + NotifyImportEventSync(opts.NotifyAPI, opts.Org, opts.Database, opts.Branch, opts.MigrationID, NotifyEventImported, notifyPayloadFromImport(opts, result)) + } + + return result, nil +} + +func importWithPgloader(ctx context.Context, opts ImportOptions, destURI, sqlitePath string, timings *ImportTimings) error { + schemaResume, err := importSchemaResumeEnabled(ctx, opts, destURI) + if err != nil { + return err + } + if !schemaResume { + opts.reportProgress(ImportProgress{Stage: ImportStageSchema}) + schemaStart := time.Now() + if err := applyPostgresSchema(ctx, opts, destURI); err != nil { + return err 
+ } + timings.SchemaMs = time.Since(schemaStart).Milliseconds() + } + dataResume, err := importDataResumeEnabled(ctx, opts, destURI) + if err != nil { + return err + } + return loadTablesAndFinalize(ctx, opts, destURI, sqlitePath, timings, dataResume) +} + +// importSmall loads dumps under 1GB: schema via psql, data via pgloader. +func importSmall(ctx context.Context, opts ImportOptions, destURI, sqlitePath string) error { + schemaResume, err := importSchemaResumeEnabled(ctx, opts, destURI) + if err != nil { + return err + } + if !schemaResume { + opts.reportProgress(ImportProgress{Stage: ImportStageSchema}) + if err := applyPostgresSchema(ctx, opts, destURI); err != nil { + return err + } + } + dataResume, err := importDataResumeEnabled(ctx, opts, destURI) + if err != nil { + return err + } + return loadTablesAndFinalize(ctx, opts, destURI, sqlitePath, nil, dataResume) +} + +func importSchemaResumeEnabled(ctx context.Context, opts ImportOptions, destURI string) (bool, error) { + state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) + if err != nil { + return false, nil + } + if state.Phase != PhaseFailed && state.Phase != PhaseImporting { + return false, nil + } + if !state.SchemaApplied { + return false, nil + } + if destURI == "" { + return false, nil + } + has, err := destHasImportTables(ctx, opts, destURI) + if err != nil { + return false, fmt.Errorf("check import tables for schema resume: %w", err) + } + return has, nil +} + +func importDataResumeEnabled(ctx context.Context, opts ImportOptions, destURI string) (bool, error) { + state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) + if err != nil { + return false, nil + } + if state.Phase != PhaseFailed && state.Phase != PhaseImporting { + return false, nil + } + if len(state.LoadedTables) == 0 { + return false, nil + } + if destURI == "" { + return false, nil + } + populated, err := populatedLoadedTables(ctx, destURI, state.LoadedTables) + if err != nil { + 
return false, fmt.Errorf("check populated tables for resume: %w", err) + } + return len(populated) > 0, nil +} + +func shouldPreserveImportProgress(ctx context.Context, opts ImportOptions, destURI string) bool { + state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) + if err != nil { + return false + } + if state.Phase != PhaseFailed && state.Phase != PhaseImporting { + return false + } + if len(state.LoadedTables) > 0 { + if destURI == "" { + return true + } + has, err := destHasImportTables(ctx, opts, destURI) + if err != nil { + return false + } + return has + } + if state.SchemaApplied { + if destURI == "" { + return true + } + has, err := destHasImportTables(ctx, opts, destURI) + if err != nil { + return false + } + return has + } + return false +} + +func destHasImportTables(ctx context.Context, opts ImportOptions, destURI string) (bool, error) { + tables, err := ParseDump(opts.InputPath) + if err != nil { + return false, err + } + existing, err := existingPublicTables(ctx, destURI, importTableNames(tables)) + if err != nil { + return false, err + } + return len(existing) > 0, nil +} + +func loadTablesAndFinalize(ctx context.Context, opts ImportOptions, destURI, sqlitePath string, timings *ImportTimings, resume bool) error { + loadTables, err := PgloaderLoadTables(opts.InputPath) + if err != nil { + return err + } + + var skipTables []string + if resume { + if state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID); err == nil && len(state.LoadedTables) > 0 { + populated, err := populatedLoadedTables(ctx, destURI, state.LoadedTables) + if err != nil { + return err + } + skipTables = populated + } + } + + pgTimings, err := RunPgloader(ctx, PgloaderOptions{ + SQLitePath: sqlitePath, + DestURI: destURI, + InputPath: opts.InputPath, + DataOnly: true, + Tables: loadTables, + SkipTables: skipTables, + OnProgress: opts.reportProgress, + PgloaderVerbose: opts.PgloaderVerbose, + OnTableLoaded: func(table string) error 
{ + return appendLoadedTable(opts, table) + }, + }) + if err != nil { + return err + } + if timings != nil { + timings.PgloaderMs = pgTimings.PgloaderMs + timings.TableLoads = pgTimings.TableLoads + } + + opts.reportProgress(ImportProgress{Stage: ImportStageIndexes}) + indexStart := time.Now() + if err := applyPostgresIndexes(ctx, opts, destURI); err != nil { + return err + } + if timings != nil { + timings.IndexBuildMs = time.Since(indexStart).Milliseconds() + } + + seqStart := time.Now() + opts.reportProgress(ImportProgress{Stage: ImportStageSequences}) + if err := ResetImportedSequences(ctx, destURI, opts.InputPath); err != nil { + return err + } + if timings != nil { + timings.SequenceResetMs = time.Since(seqStart).Milliseconds() + } + return nil +} + +// ResolveDestURI creates a short-lived Postgres role and returns a connection string. +func ResolveDestURI(ctx context.Context, psClient *ps.Client, opts ImportOptions) (string, func() error, error) { + if opts.DestURI != "" { + return opts.DestURI, func() error { return nil }, nil + } + if psClient == nil { + return "", nil, fmt.Errorf("planetscale client required for import") + } + + roleName := fmt.Sprintf("d1-import-%d", time.Now().Unix()) + role, err := roleutil.New(ctx, psClient, roleutil.Options{ + Organization: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + Name: roleName, + TTL: 2 * time.Hour, + InheritedRoles: []string{"postgres"}, + }) + if err != nil { + return "", nil, fmt.Errorf("create destination role: %w", err) + } + + dbName := opts.DBName + if dbName == "" { + dbName = "postgres" + } + + uri := postgres.BuildConnectionString(&postgres.Config{ + Host: role.Role.AccessHostURL, + Port: 5432, + User: role.Role.Username, + Password: role.Role.Password, + Database: dbName, + SSLMode: "require", + Options: map[string]string{}, + }) + + return uri, func() error { return role.Cleanup(ctx, "postgres") }, nil +} + +// ResetImportedSequences aligns identity sequences with MAX(column) after 
pgloader import. +// Per-table pgloader runs may leave sequences at their initial value; setval is idempotent. +func ResetImportedSequences(ctx context.Context, destURI, inputPath string) error { + tables, err := ParseDump(inputPath) + if err != nil { + return err + } + + db, err := OpenPostgres(destURI) + if err != nil { + return err + } + defer db.Close() + + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + for _, col := range table.Columns { + if !col.AutoIncrement { + continue + } + query := fmt.Sprintf( + `SELECT setval(pg_get_serial_sequence($1, $2), GREATEST(COALESCE((SELECT MAX(%s) FROM %s), 1), 1), true)`, + postgres.QuoteIdentifier(col.Name), + postgres.QuoteIdentifier(table.Name), + ) + if _, err := db.ExecContext(ctx, query, "public."+table.Name, col.Name); err != nil { + return fmt.Errorf("reset sequence %s.%s: %w", table.Name, col.Name, err) + } + } + } + return nil +} + +func appendLoadedTable(opts ImportOptions, table string) error { + return updateMigrationState(opts.Org, opts.Database, opts.Branch, opts.MigrationID, func(state *MigrationState) { + if slices.Contains(state.LoadedTables, table) { + return + } + state.LoadedTables = append(state.LoadedTables, table) + }) +} + +func setSchemaApplied(opts ImportOptions) error { + return updateMigrationState(opts.Org, opts.Database, opts.Branch, opts.MigrationID, func(state *MigrationState) { + state.SchemaApplied = true + }) +} + +func resetImportProgress(opts ImportOptions, phase, sqlitePath string) error { + return updateMigrationState(opts.Org, opts.Database, opts.Branch, opts.MigrationID, func(state *MigrationState) { + state.Phase = phase + state.SchemaApplied = false + state.LoadedTables = nil + if opts.InputPath != "" { + state.InputPath = opts.InputPath + } + if opts.Method != "" { + state.Method = opts.Method + } + if opts.DBName != "" { + state.DBName = opts.DBName + } + if sqlitePath != "" { + state.SQLitePath = sqlitePath + } + }) +} + +func 
applyPostgresSchema(ctx context.Context, opts ImportOptions, destURI string) error { + tables, err := ParseDump(opts.InputPath) + if err != nil { + return err + } + + importNames := importTableNames(tables) + existing, err := existingPublicTables(ctx, destURI, importNames) + if err != nil { + return err + } + if conflicts := conflictingImportTables(importNames, existing); len(conflicts) > 0 { + return errExistingImportTables(conflicts) + } + + workDir, err := os.MkdirTemp("", "pscale-d1-schema-*") + if err != nil { + return err + } + defer os.RemoveAll(workDir) + + var b strings.Builder + b.WriteString("-- Generated by pscale import d1\n") + b.WriteString("-- Source: ") + b.WriteString(opts.InputPath) + b.WriteString("\n\n") + importSQL, err := buildImportTablesSQL(opts.InputPath, tables) + if err != nil { + return err + } + b.WriteString(importSQL) + + combinedPath := filepath.Join(workDir, fmt.Sprintf("postgres-tables-%s.sql", opts.MigrationID)) + if err := os.WriteFile(combinedPath, []byte(b.String()), 0o600); err != nil { + return err + } + + if err := runPsqlFile(ctx, destURI, combinedPath); err != nil { + return err + } + return setSchemaApplied(opts) +} + +func applyPostgresIndexes(ctx context.Context, opts ImportOptions, destURI string) error { + parts, _, err := ConvertSchemaParts(opts.InputPath) + if err != nil { + return err + } + if strings.TrimSpace(parts.Indexes) == "" { + return nil + } + + workDir, err := os.MkdirTemp("", "pscale-d1-indexes-*") + if err != nil { + return err + } + defer os.RemoveAll(workDir) + + var b strings.Builder + b.WriteString("-- Generated by pscale import d1 (post-load indexes)\n") + fmt.Fprintf(&b, "SET maintenance_work_mem TO '%s';\n", pgloaderIndexMaintenanceWorkMem) + b.WriteString(parts.Indexes) + + indexPath := filepath.Join(workDir, fmt.Sprintf("postgres-indexes-%s.sql", opts.MigrationID)) + if err := os.WriteFile(indexPath, []byte(b.String()), 0o600); err != nil { + return err + } + + return runPsqlFile(ctx, destURI, 
indexPath) +} + +const ( + connectionRetryAttempts = 4 + connectionRetryBase = 2 * time.Second +) + +func isRetryableConnectionError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "bad connection") || + strings.Contains(msg, "connection reset") || + strings.Contains(msg, "couldn't read") || + strings.Contains(msg, "could not read") || + strings.Contains(msg, "broken pipe") || + strings.Contains(msg, "connection refused") || + strings.Contains(msg, "no connection") || + strings.Contains(msg, "server closed the connection") || + strings.Contains(msg, "connection timed out") || + strings.Contains(msg, "i/o timeout") +} + +func withConnectionRetry(ctx context.Context, fn func() error) error { + var lastErr error + for attempt := 0; attempt < connectionRetryAttempts; attempt++ { + if err := ctx.Err(); err != nil { + return err + } + lastErr = fn() + if lastErr == nil || !isRetryableConnectionError(lastErr) { + return lastErr + } + if attempt == connectionRetryAttempts-1 { + break + } + delay := connectionRetryBase * time.Duration(1< --input " + filepath.Base(result.InputPath) + " --dry-run", + Reason: "Preview import plan and get a migration ID", + }, + } + if result.ErrorCount == 0 { + steps = append(steps, NextStep{ + Command: "pscale import d1 start --input " + filepath.Base(result.InputPath), + Reason: "Run import after lint passes", + }) + } + return steps +} diff --git a/internal/import/d1/lint_test.go b/internal/import/d1/lint_test.go new file mode 100644 index 000000000..1a01f37f1 --- /dev/null +++ b/internal/import/d1/lint_test.go @@ -0,0 +1,213 @@ +package d1 + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func testFixture(t *testing.T) string { + t.Helper() + return filepath.Join("testdata", "sample_d1_export.sql") +} + +func requireSQLite3(t *testing.T) { + t.Helper() + if _, err := FindSQLite3(); err != nil { + t.Skip("sqlite3 not installed") + } +} + +func 
TestParseDump(t *testing.T) { + tables, err := ParseDump(testFixture(t)) + if err != nil { + t.Fatalf("ParseDump: %v", err) + } + if len(tables) != 6 { + t.Fatalf("expected 6 tables, got %d", len(tables)) + } + if tables[0].Name != "users" { + t.Fatalf("expected users table first, got %s", tables[0].Name) + } + + var teamMembers *TableSchema + for i := range tables { + if tables[i].Name == "entity_links" { + teamMembers = &tables[i] + break + } + } + if teamMembers == nil { + t.Fatal("expected entity_links table") + } + if len(teamMembers.Constraints) < 2 { + t.Fatalf("expected composite PK and FK constraints, got %v", teamMembers.Constraints) + } +} + +func TestParseIndexes(t *testing.T) { + indexes, err := ParseIndexes(testFixture(t)) + if err != nil { + t.Fatalf("ParseIndexes: %v", err) + } + if len(indexes) != 2 { + t.Fatalf("expected 2 indexes, got %d", len(indexes)) + } + if indexes[0].Name != "idx_users_email" { + t.Fatalf("unexpected first index: %s", indexes[0].Name) + } +} + +func TestLint(t *testing.T) { + result, err := Lint(testFixture(t)) + if err != nil { + t.Fatalf("Lint: %v", err) + } + if result.TableCount != 6 { + t.Fatalf("expected 6 tables, got %d", result.TableCount) + } + if result.ErrorCount != 0 { + t.Fatalf("expected no errors, got %d", result.ErrorCount) + } + if result.WarningCount == 0 { + t.Fatal("expected warnings for autoincrement/boolean columns") + } + + foundAutoincrement := false + foundDrizzle := false + for _, issue := range result.Issues { + if issue.Code == "AUTOINCREMENT" { + foundAutoincrement = true + } + if issue.Code == "DRIZZLE_MIGRATIONS" { + foundDrizzle = true + } + } + if !foundAutoincrement { + t.Fatal("expected AUTOINCREMENT issue") + } + if !foundDrizzle { + t.Fatal("expected DRIZZLE_MIGRATIONS issue") + } +} + +func TestPlan(t *testing.T) { + plan, err := Plan(PlanOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + Branch: "main", + }) + if err != nil { + t.Fatalf("Plan: %v", err) + } + if 
plan.MigrationID == "" { + t.Fatal("expected migration id") + } + if len(plan.LoadOrder) != 6 { + t.Fatalf("expected load order length 6, got %d", len(plan.LoadOrder)) + } + if plan.RecommendedMethod == "" { + t.Fatal("expected recommended method") + } +} + +func TestConvertSchema(t *testing.T) { + out := t.TempDir() + "/schema.sql" + count, err := ConvertSchema(testFixture(t), out) + if err != nil { + t.Fatalf("ConvertSchema: %v", err) + } + if count != 4 { + t.Fatalf("expected 4 tables converted, got %d", count) + } + data, err := os.ReadFile(out) + if err != nil { + t.Fatal(err) + } + content := string(data) + checks := []string{ + "GENERATED BY DEFAULT AS IDENTITY", + "BOOLEAN", + "TIMESTAMPTZ", + `FOREIGN KEY ("user_id") REFERENCES "users" ("id")`, + `PRIMARY KEY ("entity_id", "post_id")`, + `"id" UUID PRIMARY KEY`, + `"entity_id" UUID NOT NULL`, + `CREATE INDEX IF NOT EXISTS "idx_users_email"`, + `CREATE UNIQUE INDEX IF NOT EXISTS "idx_entity_links_post"`, + `UNIQUE`, + `"id" INTEGER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY`, + } + for _, check := range checks { + if !strings.Contains(content, check) { + t.Fatalf("expected schema to contain %q\n%s", check, content) + } + } + if strings.Contains(content, "__drizzle_migrations") || strings.Contains(content, "_prisma_migrations") { + t.Fatal("ORM metadata tables should be skipped in schema output") + } +} + +func TestCountInsertRows(t *testing.T) { + counts, err := CountInsertRows(testFixture(t)) + if err != nil { + t.Fatalf("CountInsertRows: %v", err) + } + if counts["users"] != 2 { + t.Fatalf("expected 2 user rows, got %d", counts["users"]) + } + if counts["posts"] != 2 { + t.Fatalf("expected 2 post rows, got %d", counts["posts"]) + } +} + +func TestStateStore(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + store, err := NewStateStore() + if err != nil { + t.Fatal(err) + } + + state := &MigrationState{ + MigrationID: "test123", + Org: "acme", + Database: "mydb", + Branch: "main", + InputPath: 
testFixture(t), + Phase: PhasePlanned, + } + if err := store.Save(state); err != nil { + t.Fatal(err) + } + + loaded, err := store.Load("acme", "mydb", "main", "test123") + if err != nil { + t.Fatal(err) + } + if loaded.MigrationID != "test123" { + t.Fatalf("expected test123, got %s", loaded.MigrationID) + } + + if err := store.Delete("acme", "mydb", "main", "test123"); err != nil { + t.Fatal(err) + } +} + +func TestConvertTableConstraint(t *testing.T) { + got := convertTableConstraint("FOREIGN KEY (team_id, user_id) REFERENCES teams(id)") + want := `FOREIGN KEY ("team_id", "user_id") REFERENCES "teams" ("id")` + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestQuoteColumnList(t *testing.T) { + got := quoteColumnList("org_id, slug") + want := `"org_id", "slug"` + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} diff --git a/internal/import/d1/notify.go b/internal/import/d1/notify.go new file mode 100644 index 000000000..9698c0317 --- /dev/null +++ b/internal/import/d1/notify.go @@ -0,0 +1,298 @@ +package d1 + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + "time" + + ps "github.com/planetscale/planetscale-go/planetscale" +) + +const importNotifyTimeout = 3 * time.Second + +// D1 import Slack notification event names. +const ( + NotifyEventStarting = "starting" + NotifyEventProgress = "progress" + NotifyEventImported = "imported" + NotifyEventVerifying = "verifying" + NotifyEventVerified = "verified" + NotifyEventComplete = "complete" + NotifyEventFailed = "failed" +) + +// NotifyAPIConfig carries the PlanetScale API client for async D1 import notifications. +type NotifyAPIConfig struct { + Client *ps.Client + // Disabled skips notifications (--no-notify). + Disabled bool +} + +// NotifyImportEvent posts a D1 import lifecycle event to api-bb asynchronously. +// Progress updates use this path; lifecycle and failure events should use NotifyImportEventSync. 
+func NotifyImportEvent(api NotifyAPIConfig, org, database, branch, migrationID, event string, extra importNotificationPayload) { + deliverImportNotification(api, org, database, branch, migrationID, event, extra, false) +} + +// NotifyImportEventSync waits briefly for api-bb to accept the notification. +// Used for lifecycle boundaries and failures so Slack is reported before the CLI exits. +func NotifyImportEventSync(api NotifyAPIConfig, org, database, branch, migrationID, event string, extra importNotificationPayload) { + deliverImportNotification(api, org, database, branch, migrationID, event, extra, true) +} + +func deliverImportNotification(api NotifyAPIConfig, org, database, branch, migrationID, event string, extra importNotificationPayload, wait bool) { + if api.Disabled || api.Client == nil { + return + } + + payload := importNotificationPayload{ + MigrationID: migrationID, + Event: event, + Method: extra.Method, + ExportBytes: extra.ExportBytes, + TableCount: extra.TableCount, + Matched: extra.Matched, + DurationMs: extra.DurationMs, + Error: extra.Error, + Stage: extra.Stage, + Message: extra.Message, + } + if branch != "" { + payload.BranchName = branch + } + + send := func() { + ctx, cancel := context.WithTimeout(context.Background(), importNotifyTimeout) + defer cancel() + _ = postImportNotification(ctx, api, org, database, payload) + } + + if wait { + send() + return + } + + go send() +} + +type importNotificationPayload struct { + BranchName string + MigrationID string + Event string + Method string + ExportBytes int64 + TableCount int + Matched *bool + DurationMs int64 + Error string + Stage string + Message string +} + +func postImportNotification(ctx context.Context, api NotifyAPIConfig, org, database string, payload importNotificationPayload) error { + return api.Client.D1ImportNotifications.Create(ctx, &ps.CreateD1ImportNotificationRequest{ + Organization: org, + Database: database, + BranchName: payload.BranchName, + MigrationID: 
payload.MigrationID, + Event: payload.Event, + Method: payload.Method, + ExportBytes: payload.ExportBytes, + TableCount: payload.TableCount, + Matched: payload.Matched, + DurationMs: payload.DurationMs, + Error: payload.Error, + Stage: payload.Stage, + Message: payload.Message, + }) +} + +func notifyPayloadFromImport(opts ImportOptions, result *ImportResult) importNotificationPayload { + payload := importNotificationPayload{ + Method: opts.Method, + } + if result != nil { + payload.TableCount = result.TablesLoaded + if result.Timings != nil { + payload.DurationMs = result.Timings.TotalMs + } + } + if info, err := os.Stat(opts.InputPath); err == nil { + payload.ExportBytes = info.Size() + } + return payload +} + +func notifyPayloadFromState(state *MigrationState) importNotificationPayload { + if state == nil { + return importNotificationPayload{} + } + + payload := importNotificationPayload{ + Method: state.Method, + } + if state.InputPath != "" { + if info, err := os.Stat(state.InputPath); err == nil { + payload.ExportBytes = info.Size() + } + } + if n := len(state.LoadedTables); n > 0 { + payload.TableCount = n + } + if !state.CreatedAt.IsZero() && !state.UpdatedAt.IsZero() { + payload.DurationMs = state.UpdatedAt.Sub(state.CreatedAt).Milliseconds() + } + return payload +} + +func notifyPayloadFromVerify(opts VerifyOptions) importNotificationPayload { + payload := importNotificationPayload{} + if state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID); err == nil { + payload.Method = state.Method + if n := len(state.LoadedTables); n > 0 { + payload.TableCount = n + } + } + if opts.InputPath != "" { + if info, err := os.Stat(opts.InputPath); err == nil { + payload.ExportBytes = info.Size() + } + } + return payload +} + +func notifyImportProgress(api NotifyAPIConfig, org, database, branch, migrationID string, base importNotificationPayload, p ImportProgress) { + if !shouldNotifyProgress(p) { + return + } + payload := base + payload.Stage = 
p.Stage + payload.Message = formatNotifyProgressMessage(p) + NotifyImportEvent(api, org, database, branch, migrationID, NotifyEventProgress, payload) +} + +func formatNotifyProgressMessage(p ImportProgress) string { + switch p.Stage { + case ImportStagePgloader: + if p.Total > 0 { + return fmt.Sprintf("Loading tables... (%d tables)", p.Total) + } + return "Loading tables..." + case VerifyStageRowCounts: + if p.Total > 0 { + return fmt.Sprintf("Comparing row counts... (%d tables)", p.Total) + } + return "Comparing row counts..." + default: + return FormatProgressMessage(p) + } +} + +func shouldNotifyProgress(p ImportProgress) bool { + switch p.Stage { + case ImportStageConnecting, ImportStageSQLiteStaging, ImportStageSchema, + ImportStageIndexes, ImportStageSequences, + VerifyStageSequences, VerifyStageBoolean, VerifyStageFingerprints, VerifyStageSampleRows: + return true + case ImportStagePgloader: + // One Slack update when table loading begins, not per table. + return p.Current == 1 && p.Total > 0 + case VerifyStageRowCounts: + // One Slack update when row-count verify begins, not per table/source. + return p.Current == 0 && p.Total > 0 + default: + return false + } +} + +// notifyImportFailure posts a failed event with structured MigrationError details. 
+func notifyImportFailure(api NotifyAPIConfig, org, database, branch, migrationID string, base importNotificationPayload, err error, verifyResult *VerifyResult) { + if err == nil { + return + } + payload := base + payload.Error = formatNotifyError(err, verifyResult) + if verifyResult != nil && !verifyResult.Matched { + matched := false + payload.Matched = &matched + } + NotifyImportEventSync(api, org, database, branch, migrationID, NotifyEventFailed, payload) +} + +func formatNotifyError(err error, verifyResult *VerifyResult) string { + if errors.Is(err, context.Canceled) { + return "[IMPORT_FAILED] import cancelled" + } + if errors.Is(err, context.DeadlineExceeded) { + return "[IMPORT_FAILED] import timed out" + } + for e := err; e != nil; e = errors.Unwrap(e) { + if me, ok := migrationErr(e); ok { + var b strings.Builder + if me.Info.Code != "" { + b.WriteString("[") + b.WriteString(me.Info.Code) + b.WriteString("] ") + } + b.WriteString(me.Info.Message) + if me.Info.Remediation != "" { + b.WriteString("\n") + b.WriteString(me.Info.Remediation) + } + if verifyResult != nil { + if summary := verifyFailureSummary(verifyResult); summary != "" { + b.WriteString("\n") + b.WriteString(summary) + } + } + return b.String() + } + } + if err != nil { + return err.Error() + } + return "" +} + +func verifyFailureSummary(result *VerifyResult) string { + if result == nil { + return "" + } + + var parts []string + for _, table := range result.Tables { + if table.Match { + continue + } + parts = append(parts, fmt.Sprintf("%s: sqlite=%d postgres=%d", table.Table, table.SourceRows, table.DestRows)) + } + for _, check := range result.Checks { + if check.Matched { + continue + } + label := check.Name + if check.Table != "" { + label = check.Table + if check.Column != "" { + label += "." 
+ check.Column + } + } + if check.Message != "" { + parts = append(parts, fmt.Sprintf("%s (%s)", label, check.Message)) + } else { + parts = append(parts, label) + } + } + + if len(parts) == 0 { + return "" + } + const maxParts = 8 + if len(parts) > maxParts { + return strings.Join(parts[:maxParts], "; ") + fmt.Sprintf("; ... and %d more", len(parts)-maxParts) + } + return strings.Join(parts, "; ") +} diff --git a/internal/import/d1/notify_test.go b/internal/import/d1/notify_test.go new file mode 100644 index 000000000..60d9ed9df --- /dev/null +++ b/internal/import/d1/notify_test.go @@ -0,0 +1,356 @@ +package d1 + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + ps "github.com/planetscale/planetscale-go/planetscale" +) + +func testNotifyClient(t *testing.T, baseURL string) *ps.Client { + t.Helper() + + client, err := ps.NewClient( + ps.WithBaseURL(baseURL), + ps.WithAccessToken("token"), + ) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + return client +} + +func TestShouldNotifyProgressMajorStages(t *testing.T) { + for _, stage := range []string{ + ImportStageConnecting, + ImportStageSQLiteStaging, + ImportStageSchema, + ImportStageIndexes, + ImportStageSequences, + } { + if !shouldNotifyProgress(ImportProgress{Stage: stage}) { + t.Fatalf("expected stage %q to notify", stage) + } + } +} + +func TestShouldNotifyProgressPgloaderTables(t *testing.T) { + if !shouldNotifyProgress(ImportProgress{Stage: ImportStagePgloader, Current: 1, Total: 19, Detail: "users"}) { + t.Fatal("expected first pgloader table to notify") + } + for _, current := range []int{0, 2, 19} { + if shouldNotifyProgress(ImportProgress{Stage: ImportStagePgloader, Current: current, Total: 19, Detail: "users"}) { + t.Fatalf("expected pgloader table %d to skip slack notification", current) + } + } +} + +func TestShouldNotifyProgressRowCounts(t *testing.T) { + if 
!shouldNotifyProgress(ImportProgress{Stage: VerifyStageRowCounts, Total: 19}) { + t.Fatal("expected row count stage start to notify") + } + for _, current := range []int{1, 2, 19} { + if shouldNotifyProgress(ImportProgress{Stage: VerifyStageRowCounts, Current: current, Total: 19, Detail: "users (sqlite)"}) { + t.Fatalf("expected row count progress %d to skip slack notification", current) + } + } +} + +func TestFormatNotifyProgressMessageAggregates(t *testing.T) { + got := formatNotifyProgressMessage(ImportProgress{Stage: ImportStagePgloader, Current: 1, Total: 19, Detail: "users"}) + want := "Loading tables... (19 tables)" + if got != want { + t.Fatalf("pgloader message = %q, want %q", got, want) + } + got = formatNotifyProgressMessage(ImportProgress{Stage: VerifyStageRowCounts, Total: 19}) + want = "Comparing row counts... (19 tables)" + if got != want { + t.Fatalf("row counts message = %q, want %q", got, want) + } +} + +func TestFormatProgressMessageSQLiteStaging(t *testing.T) { + got := FormatProgressMessage(ImportProgress{Stage: ImportStageSQLiteStaging}) + want := "Staging SQLite database from export..." 
+ if got != want { + t.Fatalf("message = %q, want %q", got, want) + } +} + +func TestShouldNotifyProgressUnknownStage(t *testing.T) { + if shouldNotifyProgress(ImportProgress{Stage: "custom_stage", Current: 1, Detail: "working"}) { + t.Fatal("expected unknown stage to skip slack notification") + } +} + +func TestImportProgressPgloaderUsesReportProgress(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + opts := ImportOptions{ + Org: "org", + Database: "db", + Branch: "main", + MigrationID: "abc123", + Method: MethodPgloader, + NotifyAPI: NotifyAPIConfig{ + Client: testNotifyClient(t, srv.URL), + }, + OnProgress: func(p ImportProgress) { + if p.Stage != ImportStagePgloader { + t.Fatalf("OnProgress stage = %q, want %q", p.Stage, ImportStagePgloader) + } + }, + } + + opts.reportProgress(ImportProgress{ + Stage: ImportStagePgloader, + Current: 1, + Total: 3, + Detail: "users", + }) + + deadline := time.After(2 * time.Second) + for calls.Load() == 0 { + select { + case <-deadline: + t.Fatal("expected pgloader progress Slack notification") + default: + time.Sleep(10 * time.Millisecond) + } + } +} + +func TestNotifyImportEventSync_CompletesBeforeReturn(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls.Add(1) + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + start := time.Now() + NotifyImportEventSync(NotifyAPIConfig{ + Client: testNotifyClient(t, srv.URL), + }, "org", "db", "main", "abc123", NotifyEventFailed, importNotificationPayload{ + Error: "boom", + }) + if elapsed := time.Since(start); elapsed < 100*time.Millisecond { + t.Fatalf("NotifyImportEventSync returned in %v, expected to wait for request", elapsed) + } + if calls.Load() != 1 { + t.Fatalf("calls = %d, want 
1", calls.Load()) + } +} + +func TestNotifyImportEvent_FireAndForget(t *testing.T) { + var calls atomic.Int32 + done := make(chan struct{}, 1) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls.Add(1) + done <- struct{}{} + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + NotifyImportEvent(NotifyAPIConfig{ + Client: testNotifyClient(t, srv.URL), + }, "org", "db", "main", "abc123", "start", importNotificationPayload{ + Method: "pgloader", + }) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("expected notification request") + } + + if calls.Load() != 1 { + t.Fatalf("calls = %d, want 1", calls.Load()) + } +} + +func TestNotifyImportEvent_SkipsWhenDisabled(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("unexpected notification request") + })) + defer srv.Close() + + NotifyImportEvent(NotifyAPIConfig{ + Client: testNotifyClient(t, srv.URL), + Disabled: true, + }, "org", "db", "main", "abc123", "start", importNotificationPayload{}) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + <-ctx.Done() +} + +func TestNotifyImportEvent_SkipsWithoutClient(t *testing.T) { + NotifyImportEvent(NotifyAPIConfig{}, "org", "db", "main", "abc123", "start", importNotificationPayload{}) +} + +func TestNotifyImportEvent_DoesNotFailImportOnAPIError(t *testing.T) { + done := make(chan struct{}, 1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + done <- struct{}{} + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + NotifyImportEvent(NotifyAPIConfig{ + Client: testNotifyClient(t, srv.URL), + }, "org", "db", "main", "abc123", "failed", importNotificationPayload{ + Error: "boom", + }) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("expected notification request") + } +} + +func 
TestPostImportNotification_ProgressSendsStageAndMessage(t *testing.T) { + var body map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + } + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal body: %v", err) + } + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + err := postImportNotification(context.Background(), NotifyAPIConfig{ + Client: testNotifyClient(t, srv.URL), + }, "org", "db", importNotificationPayload{ + MigrationID: "abc123", + Event: NotifyEventProgress, + Method: "pgloader", + Stage: ImportStageSQLiteStaging, + Message: "Staging SQLite database from export...", + }) + if err != nil { + t.Fatalf("postImportNotification: %v", err) + } + if body["event"] != "progress" { + t.Fatalf("event = %v, want progress", body["event"]) + } + if body["method"] != "pgloader" { + t.Fatalf("method = %v, want pgloader", body["method"]) + } + if body["stage"] != ImportStageSQLiteStaging { + t.Fatalf("stage = %v, want %s", body["stage"], ImportStageSQLiteStaging) + } + if body["message"] != "Staging SQLite database from export..." 
{ + t.Fatalf("message = %v", body["message"]) + } + if _, ok := body["error"]; ok { + t.Fatalf("error = %v, want omitted for progress status text", body["error"]) + } +} + +func TestPostImportNotification_UsesInternalRoute(t *testing.T) { + var path string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path = r.URL.Path + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + err := postImportNotification(context.Background(), NotifyAPIConfig{ + Client: testNotifyClient(t, srv.URL), + }, "org", "db", importNotificationPayload{ + MigrationID: "abc123", + Event: "start", + }) + if err != nil { + t.Fatalf("postImportNotification: %v", err) + } + if path != "/internal/organizations/org/databases/db/d1-import-notifications" { + t.Fatalf("path = %q", path) + } +} + +func TestFormatNotifyError_MigrationError(t *testing.T) { + err := fmt.Errorf("pgloader table team_members: %w", newMigrationError( + ErrCodeImportFailed, + `pgloader copied 0 rows into "team_members" (expected 700 from dump)`, + pgloaderNoRowsRemediation, + )) + + got := formatNotifyError(err, nil) + want := "[IMPORT_FAILED] pgloader copied 0 rows into \"team_members\" (expected 700 from dump)\n" + pgloaderNoRowsRemediation + if got != want { + t.Fatalf("formatNotifyError() = %q, want %q", got, want) + } +} + +func TestVerifyFailureSummary(t *testing.T) { + summary := verifyFailureSummary(&VerifyResult{ + Tables: []TableVerifyResult{ + {Table: "team_members", SourceRows: 700, DestRows: 0, Match: false}, + {Table: "organizations", SourceRows: 28, DestRows: 28, Match: true}, + }, + }) + if summary != "team_members: sqlite=700 postgres=0" { + t.Fatalf("summary = %q", summary) + } +} + +func TestNotifyImportFailure_SendsFailedEvent(t *testing.T) { + done := make(chan struct{}, 1) + var body map[string]any + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { done <- struct{}{} }() + raw, err := 
io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + } + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal body: %v", err) + } + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + notifyImportFailure(NotifyAPIConfig{ + Client: testNotifyClient(t, srv.URL), + }, "org", "db", "main", "abc123", importNotificationPayload{}, newMigrationError( + ErrCodeImportFailed, + "pgloader matched 0 source tables", + pgloaderNoRowsRemediation, + ), nil) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("expected notification request") + } + + if body["event"] != "failed" { + t.Fatalf("event = %v, want failed", body["event"]) + } + if !strings.Contains(body["error"].(string), "[IMPORT_FAILED]") { + t.Fatalf("error = %v, want IMPORT_FAILED prefix", body["error"]) + } +} diff --git a/internal/import/d1/orm_metadata.go b/internal/import/d1/orm_metadata.go new file mode 100644 index 000000000..42df451cc --- /dev/null +++ b/internal/import/d1/orm_metadata.go @@ -0,0 +1,156 @@ +package d1 + +import ( + "strings" +) + +type ormMetadataRule struct { + code string + orm string + remediation string + match func(table string) bool +} + +var ormMetadataRules = []ormMetadataRule{ + { + code: "DRIZZLE_MIGRATIONS", + orm: "Drizzle", + remediation: "After import, baseline Drizzle on Postgres (e.g. drizzle-kit push or a fresh migrations folder); " + + "do not rely on SQLite __drizzle_migrations history", + match: func(table string) bool { + return strings.HasPrefix(strings.ToLower(table), "__drizzle") + }, + }, + { + code: "PRISMA_MIGRATIONS", + orm: "Prisma", + remediation: "After import, baseline Prisma on Postgres (e.g. 
prisma db pull then prisma migrate resolve / new initial migration); " + + "do not import _prisma_migrations from SQLite", + match: matchTableName("_prisma_migrations"), + }, + { + code: "KNEX_MIGRATIONS", + orm: "Knex", + remediation: "After import, re-baseline Knex migration history on Postgres; knex_migrations from SQLite is not valid on Postgres", + match: matchAnyTableName("knex_migrations", "knex_migrations_lock"), + }, + { + code: "SEQUELIZE_META", + orm: "Sequelize", + remediation: "After import, re-baseline Sequelize migration history on Postgres; SequelizeMeta from SQLite is not valid on Postgres", + match: matchTableName("sequelizemeta"), + }, + { + code: "RAILS_MIGRATIONS", + orm: "Rails ActiveRecord", + remediation: "After import, re-baseline Rails schema_migrations on Postgres; SQLite migration versions do not transfer cleanly", + match: matchAnyTableName("schema_migrations", "ar_internal_metadata"), + }, + { + code: "FLYWAY_MIGRATIONS", + orm: "Flyway", + remediation: "After import, baseline Flyway on Postgres; flyway_schema_history from SQLite must not be reused", + match: matchTableName("flyway_schema_history"), + }, + { + code: "LIQUIBASE_MIGRATIONS", + orm: "Liquibase", + remediation: "After import, baseline Liquibase on Postgres; databasechangelog tables from SQLite must not be reused", + match: matchAnyTableName("databasechangelog", "databasechangeloglock"), + }, + { + code: "DJANGO_MIGRATIONS", + orm: "Django", + remediation: "After import, run django migrate --fake-initial or otherwise baseline django_migrations on Postgres", + match: matchTableName("django_migrations"), + }, + { + code: "ALEMBIC_VERSION", + orm: "Alembic", + remediation: "After import, stamp Alembic to the correct Postgres revision; alembic_version from SQLite is not portable", + match: matchTableName("alembic_version"), + }, + { + code: "TYPEORM_METADATA", + orm: "TypeORM", + remediation: "After import, baseline TypeORM migrations on Postgres; typeorm_metadata from SQLite 
is not valid on Postgres", + match: matchTableName("typeorm_metadata"), + }, + { + code: "GOOSE_MIGRATIONS", + orm: "Goose", + remediation: "After import, re-baseline Goose version table on Postgres; goose_db_version from SQLite is not portable", + match: matchTableName("goose_db_version"), + }, +} + +func matchTableName(name string) func(string) bool { + lower := strings.ToLower(name) + return func(table string) bool { + return strings.ToLower(table) == lower + } +} + +func matchAnyTableName(names ...string) func(string) bool { + set := make(map[string]struct{}, len(names)) + for _, name := range names { + set[strings.ToLower(name)] = struct{}{} + } + return func(table string) bool { + _, ok := set[strings.ToLower(table)] + return ok + } +} + +// IsORMMetadataTable reports whether a table holds ORM/framework migration bookkeeping +// that should not be imported into Postgres. +func IsORMMetadataTable(name string) bool { + return ORMMetadataRule(name) != nil +} + +// ORMMetadataRule returns the matching ORM metadata rule, if any. 
+func ORMMetadataRule(name string) *ormMetadataRule { + for i := range ormMetadataRules { + if ormMetadataRules[i].match(name) { + return &ormMetadataRules[i] + } + } + return nil +} + +func lintORMMetadata(table TableSchema) []Issue { + rule := ORMMetadataRule(table.Name) + if rule == nil { + return nil + } + issues := []Issue{{ + Code: rule.code, + Severity: SeverityInfo, + Table: table.Name, + Message: rule.orm + " migration metadata table detected", + Remediation: rule.remediation, + }} + if strings.EqualFold(table.Name, "schema_migrations") && !looksLikeRailsSchemaMigrations(table) { + issues = append(issues, Issue{ + Code: "SCHEMA_MIGRATIONS_NAME_COLLISION", + Severity: SeverityWarning, + Table: table.Name, + Message: "table name matches Rails schema_migrations but column layout does not", + Remediation: "If this is application data, rename the table before import; ORM metadata skip will exclude it from Postgres", + }) + } + return issues +} + +func looksLikeRailsSchemaMigrations(table TableSchema) bool { + if len(table.Columns) != 1 { + return false + } + col := table.Columns[0] + name := strings.ToLower(col.Name) + if name != "version" { + return false + } + t := strings.ToUpper(col.Type) + return strings.Contains(t, "CHAR") || strings.Contains(t, "TEXT") || t == "" +} diff --git a/internal/import/d1/orm_metadata_test.go b/internal/import/d1/orm_metadata_test.go new file mode 100644 index 000000000..e9496fac7 --- /dev/null +++ b/internal/import/d1/orm_metadata_test.go @@ -0,0 +1,62 @@ +package d1 + +import "testing" + +func TestIsORMMetadataTable(t *testing.T) { + tests := []struct { + table string + want bool + code string + }{ + {"__drizzle_migrations", true, "DRIZZLE_MIGRATIONS"}, + {"__drizzle_migrations_journal", true, "DRIZZLE_MIGRATIONS"}, + {"_prisma_migrations", true, "PRISMA_MIGRATIONS"}, + {"knex_migrations", true, "KNEX_MIGRATIONS"}, + {"knex_migrations_lock", true, "KNEX_MIGRATIONS"}, + {"SequelizeMeta", true, "SEQUELIZE_META"}, + 
{"schema_migrations", true, "RAILS_MIGRATIONS"}, + {"ar_internal_metadata", true, "RAILS_MIGRATIONS"}, + {"flyway_schema_history", true, "FLYWAY_MIGRATIONS"}, + {"databasechangelog", true, "LIQUIBASE_MIGRATIONS"}, + {"django_migrations", true, "DJANGO_MIGRATIONS"}, + {"alembic_version", true, "ALEMBIC_VERSION"}, + {"typeorm_metadata", true, "TYPEORM_METADATA"}, + {"goose_db_version", true, "GOOSE_MIGRATIONS"}, + {"users", false, ""}, + {"migrations", false, ""}, + {"organizations", false, ""}, + } + + for _, tc := range tests { + got := IsORMMetadataTable(tc.table) + if got != tc.want { + t.Fatalf("IsORMMetadataTable(%q) = %v, want %v", tc.table, got, tc.want) + } + if tc.want { + rule := ORMMetadataRule(tc.table) + if rule == nil || rule.code != tc.code { + t.Fatalf("ORMMetadataRule(%q) = %v, want code %q", tc.table, rule, tc.code) + } + } + } +} + +func TestLintORMMetadataTables(t *testing.T) { + result, err := Lint(testFixture(t)) + if err != nil { + t.Fatalf("Lint: %v", err) + } + + found := map[string]bool{} + for _, issue := range result.Issues { + if issue.Code == "DRIZZLE_MIGRATIONS" || issue.Code == "PRISMA_MIGRATIONS" { + found[issue.Code] = true + } + } + if !found["DRIZZLE_MIGRATIONS"] { + t.Fatal("expected DRIZZLE_MIGRATIONS lint issue") + } + if !found["PRISMA_MIGRATIONS"] { + t.Fatal("expected PRISMA_MIGRATIONS lint issue") + } +} diff --git a/internal/import/d1/output.go b/internal/import/d1/output.go new file mode 100644 index 000000000..5db62995f --- /dev/null +++ b/internal/import/d1/output.go @@ -0,0 +1,290 @@ +package d1 + +import ( + "fmt" + "time" + + "github.com/planetscale/cli/internal/printer" +) + +// PrintHumanResponse writes a human-readable success response via the shared printer. 
+func PrintHumanResponse(p *printer.Printer, resp Response) { + p.Printf("Status: %s", resp.Status) + if resp.Phase != "" { + p.Printf(" (%s)", resp.Phase) + } else if resp.Command != "" { + p.Printf(" (%s)", resp.Command) + } + p.Println() + + if resp.MigrationID != "" { + p.Printf("Migration ID: %s\n", resp.MigrationID) + } + + printHumanData(p, resp.Command, resp.Data) + + if resp.Error != nil { + p.Printf("\nError [%s]: %s\n", resp.Error.Code, resp.Error.Message) + if resp.Error.Remediation != "" { + p.Printf("%s\n", resp.Error.Remediation) + } + } + + if len(resp.Issues) > 0 { + p.Printf("\nIssues (%d):\n", len(resp.Issues)) + for _, issue := range resp.Issues { + loc := issue.Table + if issue.Column != "" { + loc += "." + issue.Column + } + p.Printf(" [%s] %s %s: %s\n", issue.Severity, issue.Code, loc, issue.Remediation) + } + } + + if len(resp.NextSteps) > 0 { + p.Println("\nNext steps:") + for _, step := range resp.NextSteps { + if step.Command != "" { + p.Printf(" - %s (%s)\n", step.Command, step.Reason) + } else if step.Tool != "" { + p.Printf(" - %s: %s\n", step.Tool, step.Reason) + } else { + p.Printf(" - %s\n", step.Reason) + } + } + } +} + +func printVerifyResultHuman(p *printer.Printer, r VerifyResult) { + matched := "no" + if r.Matched { + matched = "yes" + } + p.Printf("\nMatched: %s\n", matched) + + for _, table := range r.Tables { + if table.Match { + continue + } + p.Printf(" row count mismatch %s: sqlite=%d postgres=%d\n", table.Table, table.SourceRows, table.DestRows) + } + for _, check := range r.Checks { + if check.Matched { + continue + } + label := check.Name + if check.Table != "" { + label = check.Table + if check.Column != "" { + label += "." 
+ check.Column + } + } + if check.Message != "" { + p.Printf(" check failed %s: %s\n", label, check.Message) + } else { + p.Printf(" check failed %s\n", label) + } + } +} + +func printMigrationStateHuman(p *printer.Printer, r MigrationState) { + if r.Method != "" { + p.Printf("Method: %s\n", r.Method) + } + if len(r.LoadedTables) > 0 { + p.Printf("Tables loaded: %d\n", len(r.LoadedTables)) + } + if r.InputPath != "" { + p.Printf("Input: %s\n", r.InputPath) + } + if !r.UpdatedAt.IsZero() { + p.Printf("Updated: %s\n", r.UpdatedAt.Format(time.RFC3339)) + } +} + +func printImportResultHuman(p *printer.Printer, r ImportResult) { + p.Printf("\nMethod: %s", r.Method) + if r.DryRun { + p.Print(" (dry run)") + } + p.Println() + if r.Plan != nil { + sizeMB := float64(r.Plan.EstimatedSizeBytes) / (1024 * 1024) + p.Printf("Plan: %d tables, %.1f MB estimated\n", len(r.Plan.Tables), sizeMB) + } + if r.TablesLoaded > 0 { + p.Printf("Tables loaded: %d\n", r.TablesLoaded) + } + if r.Timings != nil && r.Timings.TotalMs > 0 { + p.Printf("Total time: %.1fs\n", float64(r.Timings.TotalMs)/1000) + } +} + +func printDoctorResultHuman(p *printer.Printer, r DoctorResult) { + p.Println("\nChecks:") + for _, c := range r.Checks { + line := fmt.Sprintf(" %s: %s", c.Name, c.Status) + if c.Version != "" { + line += fmt.Sprintf(" (%s)", c.Version) + } + p.Println(line) + } + p.Printf("Ready: %v\n", r.Ready) +} + +func printHumanData(p *printer.Printer, command string, data any) { + if data == nil { + return + } + + switch command { + case "doctor": + switch r := data.(type) { + case DoctorResult: + printDoctorResultHuman(p, r) + case *DoctorResult: + if r != nil { + printDoctorResultHuman(p, *r) + } + } + case "lint": + switch r := data.(type) { + case LintResult: + p.Printf("\nTables: %d | Errors: %d | Warnings: %d\n", r.TableCount, r.ErrorCount, r.WarningCount) + case *LintResult: + if r != nil { + p.Printf("\nTables: %d | Errors: %d | Warnings: %d\n", r.TableCount, r.ErrorCount, 
r.WarningCount) + } + } + case "start": + switch r := data.(type) { + case ImportResult: + printImportResultHuman(p, r) + case *ImportResult: + if r != nil { + printImportResultHuman(p, *r) + } + } + case "verify": + switch r := data.(type) { + case VerifyResult: + printVerifyResultHuman(p, r) + case *VerifyResult: + if r != nil { + printVerifyResultHuman(p, *r) + } + } + case "status": + switch r := data.(type) { + case MigrationState: + printMigrationStateHuman(p, r) + case *MigrationState: + if r != nil { + printMigrationStateHuman(p, *r) + } + } + case "convert-schema": + if m, ok := data.(map[string]any); ok { + p.Println() + p.Printf(" Input: %v\n", m["input"]) + p.Printf(" Output: %v\n", m["output"]) + p.Printf(" Tables: %v\n", m["table_count"]) + } + case "complete": + switch r := data.(type) { + case CompleteResult: + p.Println() + p.Printf(" Migration ID: %s\n", r.MigrationID) + p.Printf(" Status: %s\n", r.Status) + printCompleteReminderHuman(p, r) + case *CompleteResult: + if r != nil { + p.Println() + p.Printf(" Migration ID: %s\n", r.MigrationID) + p.Printf(" Status: %s\n", r.Status) + printCompleteReminderHuman(p, *r) + } + case map[string]string: + p.Println() + p.Printf(" Migration ID: %s\n", r["migration_id"]) + p.Printf(" Status: %s\n", r["status"]) + } + } +} + +// StatusResponse builds the status command envelope. +func StatusResponse(state *MigrationState) Response { + var next []NextStep + if state != nil { + next = StatusNextSteps(state) + } + resp := OKResponse("status", state, next) + if state != nil { + resp.MigrationID = state.MigrationID + resp.Phase = state.Phase + } + return resp +} + +// OKResponse builds a success response. +func OKResponse(command string, data any, next []NextStep) Response { + return Response{ + Status: "ok", + Command: command, + Data: data, + NextSteps: next, + } +} + +// ErrorResponse builds an error response from an error. 
+func ErrorResponse(command string, err error) Response { + resp := Response{ + Status: "error", + Command: command, + } + if me, ok := migrationErr(err); ok { + resp.Error = &me.Info + } else { + resp.Error = &ErrorInfo{ + Code: ErrCodeImportFailed, + Message: err.Error(), + } + } + return resp +} + +// DoctorResponse builds the doctor command envelope, including check details when not ready. +func DoctorResponse(result *DoctorResult) Response { + resp := OKResponse("doctor", result, DoctorNextSteps(result)) + if result != nil && !result.Ready { + resp.Status = "error" + if err := DoctorReadinessError(result); err != nil { + if me, ok := migrationErr(err); ok { + resp.Error = &me.Info + } else { + resp.Error = &ErrorInfo{ + Code: ErrCodePrereqFailed, + Message: err.Error(), + } + } + } + } + return resp +} + +// LintResponse builds the lint command envelope with status derived from issue severity. +func LintResponse(result *LintResult) Response { + resp := OKResponse("lint", result, LintNextSteps(result)) + resp.Issues = result.Issues + if result.ErrorCount > 0 { + resp.Status = "error" + resp.Error = &ErrorInfo{ + Code: ErrCodeLintBlocked, + Message: lintBlockedReason(result.ErrorCount), + Remediation: lintBlockedRemediation, + } + } else if result.WarningCount > 0 { + resp.Status = "warning" + } + return resp +} diff --git a/internal/import/d1/output_test.go b/internal/import/d1/output_test.go new file mode 100644 index 000000000..639619564 --- /dev/null +++ b/internal/import/d1/output_test.go @@ -0,0 +1,190 @@ +package d1 + +import ( + "bytes" + "strings" + "testing" + "time" + + "github.com/planetscale/cli/internal/printer" +) + +func TestLintResponseSetsErrorEnvelope(t *testing.T) { + result := &LintResult{ + InputPath: "/tmp/export.sql", + TableCount: 1, + ErrorCount: 1, + WarningCount: 2, + Issues: []Issue{{ + Code: "VIRTUAL_TABLE", + Severity: SeverityError, + Table: "fts", + Remediation: "Virtual tables are not supported", + }}, + } + + resp := 
LintResponse(result) + if resp.Status != "error" { + t.Fatalf("status = %q, want error", resp.Status) + } + if resp.Error == nil { + t.Fatal("expected structured error") + } + if resp.Error.Code != ErrCodeLintBlocked { + t.Fatalf("error code = %q, want %q", resp.Error.Code, ErrCodeLintBlocked) + } + if len(resp.Issues) != 1 { + t.Fatalf("issues = %d, want 1", len(resp.Issues)) + } +} + +func TestDoctorResponseIncludesChecksWhenNotReady(t *testing.T) { + result := &DoctorResult{ + Ready: false, + Checks: []DoctorCheck{{ + Name: "pgloader", + Status: checkFail, + Message: "pgloader not found", + Remediation: pgloaderInstallRemediation, + }}, + } + + resp := DoctorResponse(result) + if resp.Status != "error" { + t.Fatalf("status = %q, want error", resp.Status) + } + if resp.Error == nil || resp.Error.Code != ErrCodePrereqFailed { + t.Fatalf("error = %#v, want prereq_failed", resp.Error) + } + data, ok := resp.Data.(*DoctorResult) + if !ok || data == nil { + t.Fatalf("data = %T, want *DoctorResult", resp.Data) + } + if len(data.Checks) != 1 || data.Checks[0].Name != "pgloader" { + t.Fatalf("checks = %#v", data.Checks) + } +} + +func TestPrintHumanResponseDoctorFailureIncludesChecks(t *testing.T) { + resp := DoctorResponse(&DoctorResult{ + Ready: false, + Checks: []DoctorCheck{{ + Name: "pgloader", + Status: checkFail, + Message: "pgloader not found", + }}, + }) + + var buf bytes.Buffer + format := printer.Human + p := printer.NewPrinter(&format) + p.SetHumanOutput(&buf) + PrintHumanResponse(p, resp) + + out := buf.String() + for _, want := range []string{ + "Status: error", + "pgloader: fail", + "Ready: false", + ErrCodePrereqFailed, + } { + if !strings.Contains(out, want) { + t.Fatalf("output missing %q:\n%s", want, out) + } + } +} + +func TestPrintHumanResponseIncludesLintIssuesOnError(t *testing.T) { + resp := LintResponse(&LintResult{ + TableCount: 1, + ErrorCount: 1, + Issues: []Issue{{ + Code: "VIRTUAL_TABLE", + Severity: SeverityError, + Table: "fts", + 
Remediation: "Virtual tables are not supported", + }}, + }) + + var buf bytes.Buffer + format := printer.Human + p := printer.NewPrinter(&format) + p.SetHumanOutput(&buf) + PrintHumanResponse(p, resp) + + out := buf.String() + for _, want := range []string{ + "Status: error", + "Errors: 1", + "[error] VIRTUAL_TABLE", + "Virtual tables are not supported", + ErrCodeLintBlocked, + } { + if !strings.Contains(out, want) { + t.Fatalf("output missing %q:\n%s", want, out) + } + } +} + +func TestPrintHumanResponseStatusShowsMigrationPhase(t *testing.T) { + state := &MigrationState{ + MigrationID: "abc123", + Database: "mydb", + Branch: "main", + Phase: PhaseImported, + Method: "pgloader", + InputPath: "/tmp/export.sql", + LoadedTables: []string{"users", "posts"}, + UpdatedAt: time.Date(2026, 6, 29, 12, 0, 0, 0, time.UTC), + } + + resp := StatusResponse(state) + + var buf bytes.Buffer + format := printer.Human + p := printer.NewPrinter(&format) + p.SetHumanOutput(&buf) + PrintHumanResponse(p, resp) + + out := buf.String() + for _, want := range []string{ + "Status: ok (imported)", + "Migration ID: abc123", + "Method: pgloader", + "Tables loaded: 2", + "Input: /tmp/export.sql", + "Updated: 2026-06-29T12:00:00Z", + "Next steps:", + "pscale import d1 verify mydb --migration-id abc123 --input \"/tmp/export.sql\"", + } { + if !strings.Contains(out, want) { + t.Fatalf("output missing %q:\n%s", want, out) + } + } +} + +func TestStatusResponseSetsPhase(t *testing.T) { + state := &MigrationState{ + MigrationID: "abc123", + Database: "mydb", + Branch: "main", + Phase: PhaseVerified, + } + + resp := StatusResponse(state) + if len(resp.NextSteps) != 1 { + t.Fatalf("next_steps = %d, want 1", len(resp.NextSteps)) + } + if !strings.Contains(resp.NextSteps[0].Command, "import d1 complete mydb") { + t.Fatalf("next step = %q, want complete command", resp.NextSteps[0].Command) + } + if resp.Phase != PhaseVerified { + t.Fatalf("phase = %q, want %q", resp.Phase, PhaseVerified) + } + if 
resp.Command != "status" { + t.Fatalf("command = %q, want status", resp.Command) + } + if resp.MigrationID != "abc123" { + t.Fatalf("migration_id = %q, want abc123", resp.MigrationID) + } +} diff --git a/internal/import/d1/parse.go b/internal/import/d1/parse.go new file mode 100644 index 000000000..4a674bd92 --- /dev/null +++ b/internal/import/d1/parse.go @@ -0,0 +1,665 @@ +package d1 + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strings" +) + +var ( + createTableRe = regexp.MustCompile(`(?is)^CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(`) + virtualTableRe = regexp.MustCompile(`(?is)^CREATE\s+VIRTUAL\s+TABLE`) + autoincrementRe = regexp.MustCompile(`(?i)AUTOINCREMENT`) + columnUniqueRe = regexp.MustCompile(`(?i)\bUNIQUE\b`) + insertRe = regexp.MustCompile(`(?is)^INSERT\s+INTO\s+(?:` + "`" + `([^` + "`" + `]+)` + "`" + `|"([^"]+)"|'([^']+)'|([a-zA-Z_][\w]*))`) +) + +// TableSchema holds parsed SQLite table metadata from a dump file. +type TableSchema struct { + Name string + Columns []ColumnSchema + Constraints []string + RawDDL string +} + +// ColumnSchema holds parsed column metadata. +type ColumnSchema struct { + Name string + Type string + PrimaryKey bool + AutoIncrement bool + NotNull bool + Unique bool + DefaultValue string + ForeignKey string +} + +// ParseDump reads a SQLite SQL dump and extracts table definitions. 
+func ParseDump(path string) ([]TableSchema, error) { + clean, err := ValidateInputPath(path) + if err != nil { + return nil, err + } + f, err := os.Open(clean) + if err != nil { + return nil, err + } + defer f.Close() + + var tables []TableSchema + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 10*1024*1024) + + var current *TableSchema + var ddlLines []string + parenDepth := 0 + + flush := func() { + if current == nil { + return + } + current.RawDDL = strings.Join(ddlLines, "\n") + current.Columns, current.Constraints = parseTableBody(current.RawDDL) + tables = append(tables, *current) + current = nil + ddlLines = nil + parenDepth = 0 + } + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "--") { + continue + } + + if virtualTableRe.MatchString(line) { + return nil, newMigrationError( + ErrCodeVirtualTable, + "dump contains CREATE VIRTUAL TABLE statements", + "Remove or recreate FTS5/virtual tables manually in Postgres after migration", + ) + } + + if current == nil { + m := createTableRe.FindStringSubmatch(line) + if m == nil { + continue + } + name := firstNonEmpty(m[1], m[2], m[3], m[4]) + current = &TableSchema{Name: name} + ddlLines = append(ddlLines, line) + parenDepth += strings.Count(line, "(") - strings.Count(line, ")") + if parenDepth <= 0 && strings.HasSuffix(line, ";") { + flush() + } + continue + } + + ddlLines = append(ddlLines, line) + parenDepth += strings.Count(line, "(") - strings.Count(line, ")") + if parenDepth <= 0 && strings.HasSuffix(line, ";") { + flush() + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read dump: %w", err) + } + flush() + + if len(tables) == 0 { + return nil, newMigrationError( + ErrCodeInvalidInput, + "no CREATE TABLE statements found in dump", + "Ensure the input is a wrangler d1 export SQL file with schema definitions", + ) + } + + return tables, nil +} + +func parseTableBody(ddl string) ([]ColumnSchema, 
[]string) { + start := strings.Index(ddl, "(") + end := strings.LastIndex(ddl, ")") + if start < 0 || end <= start { + return nil, nil + } + body := stripSQLComments(ddl[start+1 : end]) + parts := splitColumnDefs(body) + cols := make([]ColumnSchema, 0, len(parts)) + var constraints []string + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + if isTableConstraint(part) { + constraints = append(constraints, part) + continue + } + col := parseColumn(part) + if col.Name != "" { + cols = append(cols, col) + } + } + return cols, constraints +} + +func parseColumn(def string) ColumnSchema { + def = strings.TrimSpace(def) + if def == "" { + return ColumnSchema{} + } + + def = strings.TrimSuffix(def, ",") + + name, rest := parseColumnNameAndRest(def) + if name == "" { + return ColumnSchema{} + } + + colType := firstToken(rest) + col := ColumnSchema{ + Name: name, + Type: colType, + } + + constraints := restAfterFirstToken(rest) + if idx := indexSQLKeyword(constraints, "DEFAULT"); idx >= 0 { + afterDefault := strings.TrimSpace(constraints[idx+len("DEFAULT"):]) + col.DefaultValue = trimDefaultClause(afterDefault) + trailing := strings.TrimSpace(afterDefault[len(col.DefaultValue):]) + constraints = strings.TrimSpace(constraints[:idx]) + if trailing != "" { + constraints = strings.TrimSpace(constraints + " " + trailing) + } + } + constraints = stripCheckClauses(constraints) + + if indexSQLKeyword(constraints, "NOT NULL") >= 0 { + col.NotNull = true + } + if indexSQLKeyword(constraints, "PRIMARY KEY") >= 0 { + col.PrimaryKey = true + } + if columnUniqueRe.MatchString(constraints) { + col.Unique = true + } + if autoincrementRe.MatchString(rest) { + col.AutoIncrement = true + } + if indexSQLKeyword(rest, "REFERENCES") >= 0 { + col.ForeignKey = referencesClause(rest) + } + + return col +} + +// ParseIndexes extracts CREATE INDEX statements from a SQLite dump. 
+func ParseIndexes(path string) ([]IndexSchema, error) { + clean, err := ValidateInputPath(path) + if err != nil { + return nil, err + } + f, err := os.Open(clean) + if err != nil { + return nil, err + } + defer f.Close() + + var indexes []IndexSchema + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 10*1024*1024) + + var stmt strings.Builder + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "--") { + continue + } + if stmt.Len() > 0 { + stmt.WriteByte(' ') + } + stmt.WriteString(line) + if !strings.HasSuffix(line, ";") { + continue + } + full := stmt.String() + stmt.Reset() + + if !strings.HasPrefix(strings.ToUpper(full), "CREATE") { + continue + } + upper := strings.ToUpper(full) + if !strings.Contains(upper, " INDEX ") { + continue + } + m := createIndexRe.FindStringSubmatch(full) + if m == nil { + continue + } + indexes = append(indexes, IndexSchema{ + Name: firstNonEmpty(m[2], m[3], m[4], m[5]), + Table: firstNonEmpty(m[6], m[7], m[8], m[9]), + Unique: strings.TrimSpace(m[1]) != "", + Columns: m[10], + RawDDL: full, + }) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read dump indexes: %w", err) + } + return indexes, nil +} + +func splitColumnDefs(body string) []string { + var parts []string + var current strings.Builder + depth := 0 + for _, r := range body { + switch r { + case '(': + depth++ + current.WriteRune(r) + case ')': + depth-- + current.WriteRune(r) + case ',': + if depth == 0 { + parts = append(parts, current.String()) + current.Reset() + continue + } + current.WriteRune(r) + default: + current.WriteRune(r) + } + } + if current.Len() > 0 { + parts = append(parts, current.String()) + } + return parts +} + +// CountInsertRows estimates row counts per table from INSERT statements. 
+func CountInsertRows(path string) (map[string]int, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + counts := make(map[string]int) + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 10*1024*1024) + + var pendingTable string + var pendingSQL strings.Builder + + flush := func() { + if pendingTable == "" { + return + } + sql := pendingSQL.String() + rows := countInsertValueGroups(sql) + if rows == 0 { + rows = 1 + } + counts[pendingTable] += rows + pendingTable = "" + pendingSQL.Reset() + } + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "--") { + continue + } + + m := insertRe.FindStringSubmatch(line) + if m != nil { + flush() + pendingTable = firstNonEmpty(m[1], m[2], m[3], m[4]) + pendingSQL.WriteString(line) + if strings.HasSuffix(line, ";") { + flush() + } + continue + } + + if pendingTable != "" { + pendingSQL.WriteString(" ") + pendingSQL.WriteString(line) + if strings.HasSuffix(line, ";") { + flush() + } + } + } + flush() + + if err := scanner.Err(); err != nil { + return nil, err + } + return counts, nil +} + +// FileSize returns the size of a file in bytes. +func FileSize(path string) (int64, error) { + info, err := os.Stat(path) + if err != nil { + return 0, err + } + return info.Size(), nil +} + +func countInsertValueGroups(line string) int { + _, groups, ok := parseInsertColumnsAndValues(line) + if !ok || len(groups) == 0 { + return 0 + } + return len(groups) +} + +func firstNonEmpty(vals ...string) string { + for _, v := range vals { + if v != "" { + return v + } + } + return "" +} + +// stripSQLComments removes -- line and /* block */ comments outside quoted strings. 
+func stripSQLComments(s string) string { + var b strings.Builder + b.Grow(len(s)) + + inSingle := false + inDouble := false + + for i := 0; i < len(s); i++ { + c := s[i] + + if inSingle { + b.WriteByte(c) + if c == '\'' { + if i+1 < len(s) && s[i+1] == '\'' { + b.WriteByte(s[i+1]) + i++ + continue + } + inSingle = false + } + continue + } + if inDouble { + b.WriteByte(c) + if c == '"' { + if i+1 < len(s) && s[i+1] == '"' { + b.WriteByte(s[i+1]) + i++ + continue + } + inDouble = false + } + continue + } + + switch c { + case '\'': + inSingle = true + b.WriteByte(c) + case '"': + inDouble = true + b.WriteByte(c) + case '-': + if i+1 < len(s) && s[i+1] == '-' { + i += 2 + for i < len(s) && s[i] != '\n' { + i++ + } + continue + } + b.WriteByte(c) + case '/': + if i+1 < len(s) && s[i+1] == '*' { + i += 2 + for i+1 < len(s) && (s[i] != '*' || s[i+1] != '/') { + i++ + } + if i+1 < len(s) { + i++ + } + continue + } + b.WriteByte(c) + default: + b.WriteByte(c) + } + } + + return b.String() +} + +func parseColumnNameAndRest(def string) (name, rest string) { + def = strings.TrimSpace(def) + def = strings.TrimSuffix(def, ",") + if def == "" { + return "", "" + } + + switch def[0] { + case '"': + end := 1 + var raw strings.Builder + for end < len(def) { + if def[end] == '"' { + if end+1 < len(def) && def[end+1] == '"' { + raw.WriteByte('"') + end += 2 + continue + } + return raw.String(), strings.TrimSpace(def[end+1:]) + } + raw.WriteByte(def[end]) + end++ + } + return "", def + case '[': + end := strings.Index(def, "]") + if end <= 1 { + return "", def + } + return def[1:end], strings.TrimSpace(def[end+1:]) + case '`': + end := strings.Index(def[1:], "`") + if end < 0 { + return "", def + } + return def[1 : end+1], strings.TrimSpace(def[end+2:]) + case '\'': + end := 1 + var raw strings.Builder + for end < len(def) { + if def[end] == '\'' { + if end+1 < len(def) && def[end+1] == '\'' { + raw.WriteByte('\'') + end += 2 + continue + } + return raw.String(), 
strings.TrimSpace(def[end+1:]) + } + raw.WriteByte(def[end]) + end++ + } + return "", def + default: + i := 0 + for i < len(def) && !isIdentBreak(def[i]) { + i++ + } + if i == 0 { + return "", def + } + return def[:i], strings.TrimSpace(def[i:]) + } +} + +func trimDefaultClause(s string) string { + s = strings.TrimSpace(s) + s = strings.TrimSuffix(s, ",") + stopPatterns := []string{ + " NOT NULL", + " NULL", + " UNIQUE", + " PRIMARY KEY", + " REFERENCES", + " CHECK", + " COLLATE", + " GENERATED", + } + best := len(s) + upper := strings.ToUpper(s) + for _, pat := range stopPatterns { + if i := indexOutsideQuotes(upper, pat); i >= 0 && i < best { + best = i + } + } + if best < len(s) { + s = strings.TrimSpace(s[:best]) + } + return strings.TrimSuffix(strings.TrimSpace(s), ",") +} + +func restAfterFirstToken(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + i := 0 + for i < len(s) && !isIdentBreak(s[i]) { + i++ + } + return strings.TrimSpace(s[i:]) +} + +func indexOutsideQuotes(s, pattern string) int { + if pattern == "" { + return -1 + } + inQuote := byte(0) + for i := 0; i+len(pattern) <= len(s); i++ { + switch { + case inQuote != 0: + if s[i] == inQuote { + inQuote = 0 + } + case s[i] == '\'' || s[i] == '"': + inQuote = s[i] + case strings.EqualFold(s[i:i+len(pattern)], pattern): + return i + } + } + return -1 +} + +func indexSQLKeyword(s, keyword string) int { + if s == "" || keyword == "" { + return -1 + } + upper := strings.ToUpper(s) + kw := strings.ToUpper(keyword) + for i := 0; i+len(kw) <= len(upper); i++ { + if upper[i:i+len(kw)] != kw { + continue + } + if i > 0 && isSQLIdentChar(upper[i-1]) { + continue + } + end := i + len(kw) + if end < len(upper) && isSQLIdentChar(upper[end]) { + continue + } + return i + } + return -1 +} + +func isSQLIdentChar(c byte) bool { + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_' +} + +func stripCheckClauses(s string) string { + upper := 
strings.ToUpper(s) + var out strings.Builder + for i := 0; i < len(s); { + if strings.HasPrefix(upper[i:], "CHECK") && (i+5 == len(s) || !isSQLIdentChar(upper[i+5])) { + j := i + 5 + for j < len(s) && (s[j] == ' ' || s[j] == '\t') { + j++ + } + if j < len(s) && s[j] == '(' { + if end, ok := matchingParenEnd(s, j); ok { + i = end + 1 + continue + } + } + } + out.WriteByte(s[i]) + i++ + } + return out.String() +} + +func matchingParenEnd(s string, open int) (int, bool) { + if open >= len(s) || s[open] != '(' { + return 0, false + } + depth := 0 + inQuote := byte(0) + for i := open; i < len(s); i++ { + c := s[i] + if inQuote != 0 { + if c == inQuote && (i == 0 || s[i-1] != '\\') { + inQuote = 0 + } + continue + } + switch c { + case '\'', '"', '`': + inQuote = c + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + return i, true + } + } + } + return 0, false +} + +func isIdentBreak(c byte) bool { + switch c { + case ' ', '\t', '\n', '\r', '(', ')', ',': + return true + default: + return false + } +} + +func firstToken(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + i := 0 + for i < len(s) && !isIdentBreak(s[i]) { + i++ + } + return strings.ToUpper(s[:i]) +} diff --git a/internal/import/d1/parse_test.go b/internal/import/d1/parse_test.go new file mode 100644 index 000000000..2561016a9 --- /dev/null +++ b/internal/import/d1/parse_test.go @@ -0,0 +1,60 @@ +package d1 + +import "testing" + +func TestParseColumnDefaultBeforeNotNull(t *testing.T) { + col := parseColumn("active INTEGER DEFAULT 1 NOT NULL") + if col.DefaultValue != "1" { + t.Fatalf("default = %q, want 1", col.DefaultValue) + } + if !col.NotNull { + t.Fatal("expected NOT NULL") + } +} + +func TestParseColumnDefaultStringNotNullNotConstraint(t *testing.T) { + col := parseColumn("status TEXT DEFAULT 'value NOT NULL'") + if col.DefaultValue != "'value NOT NULL'" { + t.Fatalf("default = %q, want quoted literal", col.DefaultValue) + } + if col.NotNull { + t.Fatal("NOT NULL 
inside default string must not set column constraint") + } +} + +func TestParseColumnCheckNotNullNotConstraint(t *testing.T) { + col := parseColumn("status TEXT CHECK (status IS NOT NULL)") + if col.NotNull { + t.Fatal("NOT NULL inside CHECK must not set column constraint") + } +} + +func TestTrimDefaultClause(t *testing.T) { + cases := map[string]string{ + "1 NOT NULL": "1", + "'draft' NOT NULL UNIQUE": "'draft'", + "CURRENT_TIMESTAMP": "CURRENT_TIMESTAMP", + } + for in, want := range cases { + if got := trimDefaultClause(in); got != want { + t.Fatalf("trimDefaultClause(%q) = %q, want %q", in, got, want) + } + } +} + +func TestParseColumnUniqueConstraint(t *testing.T) { + col := parseColumn("email TEXT NOT NULL UNIQUE") + if !col.Unique { + t.Fatal("expected column-level UNIQUE constraint") + } + + col = parseColumn("unique_token TEXT NOT NULL") + if col.Unique { + t.Fatalf("identifier unique_token should not be treated as UNIQUE constraint") + } + + col = parseColumn("unique_id INTEGER PRIMARY KEY") + if col.Unique { + t.Fatalf("identifier unique_id should not be treated as UNIQUE constraint") + } +} diff --git a/internal/import/d1/path.go b/internal/import/d1/path.go new file mode 100644 index 000000000..d926df356 --- /dev/null +++ b/internal/import/d1/path.go @@ -0,0 +1,81 @@ +package d1 + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +// ValidateInputPath ensures a user-supplied path is safe to read. 
+func ValidateInputPath(path string) (string, error) { + if path == "" { + return "", newMigrationError(ErrCodeMissingInput, "input path is required", "Pass --input with a D1 SQL export file") + } + if strings.ContainsAny(path, "\x00\n\r;") { + return "", newMigrationError(ErrCodeInvalidInput, "invalid characters in input path", "Use a simple file path without newlines or semicolons") + } + + clean := filepath.Clean(path) + info, err := os.Stat(clean) + if err != nil { + if os.IsNotExist(err) { + return "", errMissingInput(clean) + } + return "", fmt.Errorf("stat input: %w", err) + } + if info.IsDir() { + return "", newMigrationError(ErrCodeInvalidInput, "input path is a directory", "Pass a .sql export file path") + } + return clean, nil +} + +// NormalizeInputPath validates path and returns an absolute path for stable state comparisons. +func NormalizeInputPath(path string) (string, error) { + clean, err := ValidateInputPath(path) + if err != nil { + return "", err + } + abs, err := filepath.Abs(clean) + if err != nil { + return clean, nil + } + return abs, nil +} + +func normalizePathForCompare(path string) string { + if path == "" { + return "" + } + abs, err := filepath.Abs(path) + if err != nil { + abs = path + } + eval, err := filepath.EvalSymlinks(abs) + if err != nil { + return filepath.Clean(abs) + } + return eval +} + +func validateInputPathAgainstState(provided, saved string) error { + if provided == "" || saved == "" { + return nil + } + if normalizePathForCompare(provided) != normalizePathForCompare(saved) { + return newMigrationError( + ErrCodeInvalidInput, + fmt.Sprintf("input path %q does not match migration state %q", provided, saved), + "Use the same --input as the original import or omit --input to use saved state", + ) + } + return nil +} + +// DefaultSQLitePath returns a sqlite path adjacent to the dump. 
+func DefaultSQLitePath(dumpPath string) string { + base := filepath.Base(dumpPath) + ext := filepath.Ext(base) + name := base[:len(base)-len(ext)] + return filepath.Join(filepath.Dir(dumpPath), name+".sqlite") +} diff --git a/internal/import/d1/path_test.go b/internal/import/d1/path_test.go new file mode 100644 index 000000000..9f3e94e10 --- /dev/null +++ b/internal/import/d1/path_test.go @@ -0,0 +1,45 @@ +package d1 + +import ( + "os" + "path/filepath" + "testing" +) + +func TestValidateInputPathAgainstStateEquivalentPaths(t *testing.T) { + dir := t.TempDir() + dump := filepath.Join(dir, "export.sql") + if err := os.WriteFile(dump, []byte("SELECT 1;\n"), 0o600); err != nil { + t.Fatalf("write dump: %v", err) + } + + rel := "./export.sql" + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir: %v", err) + } + t.Cleanup(func() { _ = os.Chdir(origDir) }) + + if err := validateInputPathAgainstState(rel, dump); err != nil { + t.Fatalf("expected equivalent paths to match: %v", err) + } +} + +func TestNormalizeInputPathReturnsAbsolute(t *testing.T) { + dir := t.TempDir() + dump := filepath.Join(dir, "export.sql") + if err := os.WriteFile(dump, []byte("SELECT 1;\n"), 0o600); err != nil { + t.Fatalf("write dump: %v", err) + } + + got, err := NormalizeInputPath(dump) + if err != nil { + t.Fatalf("NormalizeInputPath: %v", err) + } + if !filepath.IsAbs(got) { + t.Fatalf("path = %q, want absolute", got) + } +} diff --git a/internal/import/d1/pgloader.go b/internal/import/d1/pgloader.go new file mode 100644 index 000000000..3166df069 --- /dev/null +++ b/internal/import/d1/pgloader.go @@ -0,0 +1,513 @@ +package d1 + +import ( + "context" + _ "embed" + "fmt" + "os" + "path/filepath" + "regexp" + "slices" + "strconv" + "strings" + "time" + + "github.com/planetscale/cli/internal/postgres" + execabs "golang.org/x/sys/execabs" +) + +//go:embed pgloader_transforms.lisp +var pgloaderTransformsLisp 
string + +const ( + pgloaderBatchSize = "20 MB" + pgloaderDynamicSpace = "4096" // MB per pgloader process (SBCL heap cap) + + pgloaderLargeTableRowThreshold = 100_000 + + // Fast profile: small/medium tables after indexes are deferred. + pgloaderFastPrefetchRows = 25000 + pgloaderFastBatchRows = 25000 + pgloaderFastWorkers = 8 + pgloaderFastConcurrency = 2 + + // Conservative profile: wide rows / large tables (e.g. attachments). + pgloaderSlowPrefetchRows = 5000 + pgloaderSlowBatchRows = 10000 + pgloaderSlowWorkers = 2 + pgloaderSlowConcurrency = 1 + + pgloaderLoadWorkMem = "256MB" + pgloaderLoadMaintenanceWorkMem = "512MB" + pgloaderIndexMaintenanceWorkMem = "2GB" + + pgloaderNoRowsRemediation = "Check pgloader stderr for table filter or cast errors; re-run import d1 start after fixing the dump or CLI" +) + +var ( + pgloaderSummaryErrorRe = regexp.MustCompile(`(?m)^\|\s+(\d+)\s+\|`) + pgloaderFetchMetaDataRe = regexp.MustCompile(`(?m)^\s*fetch meta data\s+\d+\s+(\d+)`) +) + +// PgloaderOptions configures pgloader execution. +type PgloaderOptions struct { + SQLitePath string + DestURI string + InputPath string // dump path for column-level CAST rules + WorkDir string + DryRun bool + DataOnly bool + // Tables loads one table per pgloader invocation when set (recommended for + // large databases — avoids SBCL heap exhaustion from whole-catalog planning). + Tables []string + // SkipTables skips tables already loaded during a resumed import. + SkipTables []string + // OnTableLoaded is called after each table load succeeds (for resume checkpoints). + OnTableLoaded func(table string) error + // OnProgress reports per-table load progress. + OnProgress ImportProgressFunc + // PgloaderVerbose writes full pgloader output to stderr after each table. 
+ PgloaderVerbose bool +} + +type pgloaderMemoryProfile struct { + prefetchRows int + batchRows int + workers int + concurrency int +} + +func pgloaderProfileForTable(rowCount int) pgloaderMemoryProfile { + if rowCount >= pgloaderLargeTableRowThreshold { + return pgloaderMemoryProfile{ + prefetchRows: pgloaderSlowPrefetchRows, + batchRows: pgloaderSlowBatchRows, + workers: pgloaderSlowWorkers, + concurrency: pgloaderSlowConcurrency, + } + } + return pgloaderMemoryProfile{ + prefetchRows: pgloaderFastPrefetchRows, + batchRows: pgloaderFastBatchRows, + workers: pgloaderFastWorkers, + concurrency: pgloaderFastConcurrency, + } +} + +// PgloaderLoadTables returns non-ORM tables in FK-safe load order. +func PgloaderLoadTables(inputPath string) ([]string, error) { + tables, err := ParseDump(inputPath) + if err != nil { + return nil, err + } + ordered := topologicalLoadOrder(tables) + out := make([]string, 0, len(ordered)) + for _, name := range ordered { + if !IsORMMetadataTable(name) { + out = append(out, name) + } + } + return out, nil +} + +// RunPgloader loads SQLite into PostgreSQL using pgloader. 
+func RunPgloader(ctx context.Context, opts PgloaderOptions) (ImportTimings, error) { + var timings ImportTimings + pgloader, err := FindPgloader() + if err != nil { + return timings, err + } + + if opts.WorkDir == "" { + opts.WorkDir, err = os.MkdirTemp("", "pscale-d1-pgloader-*") + if err != nil { + return timings, err + } + defer os.RemoveAll(opts.WorkDir) + } + + tables := opts.Tables + tableSchemas, err := ParseDump(opts.InputPath) + if err != nil { + return timings, err + } + tableByName := make(map[string]TableSchema, len(tableSchemas)) + for _, t := range tableSchemas { + tableByName[t.Name] = t + } + + rowCounts, err := sqliteStagingRowCounts(ctx, opts.SQLitePath, tables) + if err != nil { + return timings, fmt.Errorf("count sqlite staging rows: %w", err) + } + coerceCtx, err := BuildTypeCoercionContext(opts.InputPath, tableSchemas) + if err != nil { + return timings, err + } + + if len(tables) == 0 { + pgStart := time.Now() + if err := runPgloaderScript(ctx, pgloader, opts, pgloaderScriptConfig{ + dataOnly: opts.DataOnly, + resetSequences: true, + profile: pgloaderProfileForTable(0), + }, TableSchema{}, tableSchemas, 0, coerceCtx); err != nil { + return timings, err + } + timings.PgloaderMs = time.Since(pgStart).Milliseconds() + return timings, nil + } + + pgStart := time.Now() + totalTables := 0 + for _, name := range tables { + if !slices.Contains(opts.SkipTables, name) { + totalTables++ + } + } + loaded := 0 + for _, name := range tables { + if slices.Contains(opts.SkipTables, name) { + continue + } + table, ok := tableByName[name] + if !ok { + return timings, fmt.Errorf("pgloader table %s: not found in dump schema", name) + } + if opts.OnProgress != nil { + opts.OnProgress(ImportProgress{ + Stage: ImportStagePgloader, + Current: loaded + 1, + Total: totalTables, + Detail: name, + }) + } + profile := pgloaderProfileForTable(int(rowCounts[name])) + tableStart := time.Now() + if err := runPgloaderScript(ctx, pgloader, opts, pgloaderScriptConfig{ + 
dataOnly: opts.DataOnly, + tableName: name, + resetSequences: true, + profile: profile, + }, table, tableSchemas, rowCounts[name], coerceCtx); err != nil { + return timings, fmt.Errorf("pgloader table %s: %w", name, err) + } + timings.TableLoads = append(timings.TableLoads, TableLoadTiming{ + Table: name, + Ms: time.Since(tableStart).Milliseconds(), + }) + if opts.OnTableLoaded != nil { + if err := opts.OnTableLoaded(name); err != nil { + return timings, err + } + } + loaded++ + } + timings.PgloaderMs = time.Since(pgStart).Milliseconds() + return timings, nil +} + +type pgloaderScriptConfig struct { + dataOnly bool + tableName string + resetSequences bool + profile pgloaderMemoryProfile +} + +func runPgloaderScript(ctx context.Context, pgloader string, opts PgloaderOptions, cfg pgloaderScriptConfig, table TableSchema, allTables []TableSchema, expectedRows int64, coerceCtx *TypeCoercionContext) error { + loadFile := filepath.Join(opts.WorkDir, "load.load") + if cfg.tableName != "" { + loadFile = filepath.Join(opts.WorkDir, "load-"+cfg.tableName+".load") + } + castTables := allTables + if table.Name != "" { + castTables = []TableSchema{table} + } + content := buildPgloaderScript(opts.SQLitePath, opts.DestURI, cfg, castTables, allTables, coerceCtx) + if err := os.WriteFile(loadFile, []byte(content), 0o600); err != nil { + return err + } + + if opts.DryRun { + return nil + } + + transformsFile := filepath.Join(opts.WorkDir, "transforms.lisp") + if err := os.WriteFile(transformsFile, []byte(pgloaderTransformsLisp), 0o600); err != nil { + return err + } + + var out []byte + err := withConnectionRetry(ctx, func() error { + cmd := execabs.CommandContext(ctx, pgloader, "--load-lisp-file", transformsFile, loadFile) + cmd.Env = append(os.Environ(), + "SBCL_OPTIONS=--dynamic-space-size "+pgloaderDynamicSpace, + ) + var runErr error + out, runErr = cmd.CombinedOutput() + if runErr != nil { + return fmt.Errorf("pgloader: %w: %s", runErr, string(out)) + } + return nil + }) + 
output := string(out) + if err != nil { + emitPgloaderOutput(opts, output, true) + return fmt.Errorf("pgloader failed: %w: %s", err, output) + } + if strings.Contains(output, "FATAL") || strings.Contains(output, "KABOOM") || + strings.Contains(output, "ERROR Error while formatting") || + strings.Contains(output, "ERROR The value") || + strings.Contains(output, "Heap exhausted") || + pgloaderHadErrors(output) { + emitPgloaderOutput(opts, output, true) + return fmt.Errorf("pgloader failed: %s", output) + } + if cfg.tableName != "" { + if err := validatePgloaderTableLoad(output, cfg.tableName, expectedRows); err != nil { + emitPgloaderOutput(opts, output, true) + return err + } + } + emitPgloaderOutput(opts, output, false) + return nil +} + +func emitPgloaderOutput(opts PgloaderOptions, output string, force bool) { + if output == "" { + return + } + if force || opts.PgloaderVerbose { + fmt.Fprint(os.Stderr, output) + } +} + +// pgloaderHadErrors inspects pgloader output for failures that do not set exit code. +func pgloaderHadErrors(output string) bool { + if strings.Contains(output, "Database error") || + strings.Contains(output, "INSUFFICIENT-PRIVILEGE") || + strings.Contains(output, "must be owner of table") { + return true + } + for _, match := range pgloaderSummaryErrorRe.FindAllStringSubmatch(output, -1) { + if len(match) < 2 { + continue + } + if match[1] != "0" { + return true + } + } + return false +} + +// pgloaderFetchMetaDataTableCount returns how many source tables pgloader matched, or -1 if absent. +func pgloaderFetchMetaDataTableCount(output string) int { + matches := pgloaderFetchMetaDataRe.FindAllStringSubmatch(output, -1) + if len(matches) == 0 { + return -1 + } + n, err := strconv.Atoi(matches[len(matches)-1][1]) + if err != nil { + return -1 + } + return n +} + +// pgloaderRowsCopied parses the pgloader report summary row count for table, if present. 
+func pgloaderRowsCopied(output, table string) (int64, bool) { + re := regexp.MustCompile(`(?m)^\s*` + regexp.QuoteMeta(table) + `\s+\d+\s+(\d+)\s+`) + m := re.FindStringSubmatch(output) + if len(m) < 2 { + return 0, false + } + n, err := strconv.ParseInt(m[1], 10, 64) + if err != nil { + return 0, false + } + return n, true +} + +func validatePgloaderTableLoad(output, table string, expectedRows int64) error { + metaCount := pgloaderFetchMetaDataTableCount(output) + if metaCount == 0 { + msg := fmt.Sprintf("pgloader matched 0 source tables for %q", table) + if expectedRows > 0 { + msg = fmt.Sprintf("pgloader matched 0 source tables for %q (expected %d rows from staged SQLite)", table, expectedRows) + } + return newMigrationError(ErrCodeImportFailed, msg, pgloaderNoRowsRemediation) + } + + rows, found := pgloaderRowsCopied(output, table) + if !found { + return newMigrationError( + ErrCodeImportFailed, + fmt.Sprintf("pgloader summary missing row count for %q", table), + pgloaderNoRowsRemediation, + ) + } + if rows != expectedRows { + return newMigrationError( + ErrCodeImportFailed, + fmt.Sprintf("pgloader copied %d rows into %q (expected %d from staged SQLite)", rows, table, expectedRows), + pgloaderNoRowsRemediation, + ) + } + + return nil +} + +func buildPgloaderScript(sqlitePath, destURI string, cfg pgloaderScriptConfig, castTables, allTables []TableSchema, coerceCtx *TypeCoercionContext) string { + absSQLite, _ := filepath.Abs(sqlitePath) + src := "sqlite:///" + strings.ReplaceAll(absSQLite, " ", "%20") + target := destURI + if parsed, err := postgres.ParseConnectionURI(destURI); err == nil { + target = postgres.BuildConnectionURI(parsed) + } + + profile := cfg.profile + if profile.workers == 0 { + profile = pgloaderProfileForTable(0) + } + + var b strings.Builder + b.WriteString("LOAD DATABASE\n") + fmt.Fprintf(&b, " FROM %s\n", src) + fmt.Fprintf(&b, " INTO %s\n", target) + b.WriteString("\n") + + if cfg.dataOnly { + b.WriteString(" WITH data only, create no 
tables, create no indexes, truncate, disable triggers,\n") + if cfg.resetSequences { + b.WriteString(" reset sequences,\n") + } else { + b.WriteString(" reset no sequences,\n") + } + fmt.Fprintf(&b, " workers = %d, concurrency = %d,\n", profile.workers, profile.concurrency) + fmt.Fprintf(&b, " batch rows = %d,\n", profile.batchRows) + fmt.Fprintf(&b, " batch size = %s,\n", pgloaderBatchSize) + fmt.Fprintf(&b, " prefetch rows = %d\n", profile.prefetchRows) + } else { + b.WriteString(" WITH include drop, create tables, create indexes, reset sequences,\n") + fmt.Fprintf(&b, " workers = %d, concurrency = %d,\n", profile.workers, profile.concurrency) + fmt.Fprintf(&b, " batch rows = %d,\n", profile.batchRows) + fmt.Fprintf(&b, " batch size = %s,\n", pgloaderBatchSize) + fmt.Fprintf(&b, " prefetch rows = %d\n", profile.prefetchRows) + } + + if cfg.tableName != "" { + b.WriteString("\n") + tableNames := tableNames(allTables) + fmt.Fprintf(&b, " INCLUDING ONLY TABLE NAMES%s\n", pgloaderTableNameFilter(cfg.tableName, tableNames)) + } + + appendPgloaderCasts(&b, castTables, allTables, coerceCtx) + + b.WriteString("\n") + fmt.Fprintf(&b, " SET work_mem to '%s', maintenance_work_mem to '%s', synchronous_commit to 'off';\n", + pgloaderLoadWorkMem, pgloaderLoadMaintenanceWorkMem) + return b.String() +} + +func appendPgloaderCasts(b *strings.Builder, castTables, allTables []TableSchema, coerceCtx *TypeCoercionContext) { + var rules []string + for _, table := range castTables { + for _, col := range table.Columns { + pgType := sqliteTypeToPostgres(col, table, allTables, coerceCtx) + ref := fmt.Sprintf("column %s.%s", table.Name, col.Name) + switch pgType { + case "BOOLEAN": + rules = append(rules, ref+" to boolean using sqlite-int-to-boolean") + case "TIMESTAMPTZ": + rules = append(rules, ref+" to timestamptz using sqlite-timestamp-to-timestamp") + case "JSONB": + rules = append(rules, ref+" to jsonb using sqlite-text-to-jsonb") + case "UUID": + rules = append(rules, ref+" to uuid 
using sqlite-text-to-uuid") + } + } + } + if len(rules) == 0 { + return + } + b.WriteString("\n CAST ") + for i, rule := range rules { + if i > 0 { + b.WriteString(",\n ") + } else { + b.WriteString("\n ") + } + b.WriteString(rule) + } +} + +// pgloaderTableNameFilter returns a pgloader INCLUDING ONLY ... LIKE filter for one table. +// pgloader 3.6.x accepts LIKE 'name' but does not parse ESCAPE clauses, so names with +// LIKE metacharacters add EXCLUDING filters for other tables that would false-match. +func pgloaderTableNameFilter(name string, allTableNames []string) string { + var b strings.Builder + fmt.Fprintf(&b, " LIKE '%s'", escapePgloaderQuote(name)) + if !strings.ContainsAny(name, "_%") { + return b.String() + } + for _, other := range allTableNames { + if other == name { + continue + } + if sqlLikeMatch(name, other) { + fmt.Fprintf(&b, "\n EXCLUDING TABLE NAMES LIKE '%s'", escapePgloaderQuote(other)) + } + } + return b.String() +} + +func sqliteStagingRowCounts(ctx context.Context, sqlitePath string, tables []string) (map[string]int64, error) { + if sqlitePath == "" { + if len(tables) == 0 { + return map[string]int64{}, nil + } + return nil, fmt.Errorf("sqlite staging path required for per-table pgloader validation") + } + if len(tables) == 0 { + return map[string]int64{}, nil + } + return CountSQLiteRows(ctx, sqlitePath, tables) +} + +func tableNames(tables []TableSchema) []string { + names := make([]string, 0, len(tables)) + for _, table := range tables { + names = append(names, table.Name) + } + return names +} + +func sqlLikeMatch(pattern, s string) bool { + m, n := len(pattern), len(s) + dp := make([][]bool, m+1) + for i := range dp { + dp[i] = make([]bool, n+1) + } + dp[0][0] = true + for i := 1; i <= m; i++ { + if pattern[i-1] == '%' { + dp[i][0] = dp[i-1][0] + } + } + for i := 1; i <= m; i++ { + for j := 1; j <= n; j++ { + switch pattern[i-1] { + case '%': + dp[i][j] = dp[i-1][j] || dp[i][j-1] + case '_': + dp[i][j] = dp[i-1][j-1] + default: + 
dp[i][j] = dp[i-1][j-1] && pattern[i-1] == s[j-1] + } + } + } + return dp[m][n] +} + +func escapePgloaderQuote(name string) string { + return strings.ReplaceAll(name, "'", "''") +} diff --git a/internal/import/d1/pgloader_test.go b/internal/import/d1/pgloader_test.go new file mode 100644 index 000000000..090ff7157 --- /dev/null +++ b/internal/import/d1/pgloader_test.go @@ -0,0 +1,329 @@ +package d1 + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +const pgloaderOrganizationsOK = ` +2026-06-29T17:07:37.780572-04:00 LOG report summary reset + table name errors rows bytes total time +----------------------- --------- --------- --------- -------------- + fetch 0 0 0.000s + fetch meta data 0 1 0.021s + Truncate 0 1 0.053s + organizations 0 28 2.7 kB 0.775s + Total import time ✓ 28 2.7 kB 1.770s +` + +const pgloaderTeamMembersZero = ` +2026-06-29T17:10:50.659297-04:00 LOG report summary reset + table name errors rows bytes total time +----------------------- --------- --------- --------- -------------- + fetch 0 0 0.000s + fetch meta data 0 0 0.012s + Total import time ✓ 0 0.966s +` + +func TestBuildPgloaderScriptDataOnlyPerTable(t *testing.T) { + table := TableSchema{ + Name: "organizations", + Columns: []ColumnSchema{ + {Name: "id", Type: "INTEGER", PrimaryKey: true, AutoIncrement: true}, + {Name: "slug", Type: "TEXT", NotNull: true}, + {Name: "is_active", Type: "INTEGER", NotNull: true}, + {Name: "created_at", Type: "TEXT", NotNull: true}, + }, + } + script := buildPgloaderScript("/tmp/test.sqlite", "postgresql://u:p@host/db", pgloaderScriptConfig{ + dataOnly: true, + tableName: "organizations", + resetSequences: false, + profile: pgloaderProfileForTable(0), + }, []TableSchema{table}, []TableSchema{table}, nil) + + checks := []string{ + "WITH data only, create no tables, create no indexes, truncate, disable triggers,", + "reset no sequences,", + "workers = 8, concurrency = 2,", + "batch rows = 25000,", + "batch size = 20 MB,", + "prefetch rows = 
25000", + "INCLUDING ONLY TABLE NAMES LIKE 'organizations'", + "SET work_mem to '256MB'", + "synchronous_commit to 'off'", + } + for _, want := range checks { + if !strings.Contains(script, want) { + t.Fatalf("script missing %q\n%s", want, script) + } + } + for _, bad := range []string{ + "column organizations.id to boolean", + "column organizations.is_active to boolean", + "column organizations.slug to timestamptz", + "type integer to boolean", + "type text to timestamptz", + } { + if strings.Contains(script, bad) { + t.Fatalf("script should not contain %q\n%s", bad, script) + } + } +} + +func TestBuildPgloaderScriptLargeTableProfile(t *testing.T) { + script := buildPgloaderScript("/tmp/test.sqlite", "postgresql://u:p@host/db", pgloaderScriptConfig{ + dataOnly: true, + tableName: "attachments", + resetSequences: true, + profile: pgloaderProfileForTable(pgloaderLargeTableRowThreshold), + }, nil, nil, nil) + + for _, want := range []string{ + "workers = 2, concurrency = 1,", + "batch rows = 10000,", + "prefetch rows = 5000", + } { + if !strings.Contains(script, want) { + t.Fatalf("script missing %q\n%s", want, script) + } + } +} + +func TestBuildPgloaderScriptUUIDCast(t *testing.T) { + tables, err := ParseDump(testFixture(t)) + if err != nil { + t.Fatalf("ParseDump: %v", err) + } + ctx, err := BuildTypeCoercionContext(testFixture(t), tables) + if err != nil { + t.Fatalf("BuildTypeCoercionContext: %v", err) + } + var entityLinks *TableSchema + for i := range tables { + if tables[i].Name == "entity_links" { + entityLinks = &tables[i] + break + } + } + if entityLinks == nil { + t.Fatal("expected entity_links table") + } + + script := buildPgloaderScript("/tmp/test.sqlite", "postgresql://u:p@host/db", pgloaderScriptConfig{ + dataOnly: true, + tableName: "entity_links", + profile: pgloaderProfileForTable(0), + }, []TableSchema{*entityLinks}, tables, ctx) + + for _, want := range []string{ + "column entity_links.entity_id to uuid using sqlite-text-to-uuid", + "column 
entity_links.linked_at to timestamptz using sqlite-timestamp-to-timestamp", + } { + if !strings.Contains(script, want) { + t.Fatalf("script missing %q\n%s", want, script) + } + } + + var externalEntities *TableSchema + for i := range tables { + if tables[i].Name == "external_entities" { + externalEntities = &tables[i] + break + } + } + if externalEntities == nil { + t.Fatal("expected external_entities table") + } + script = buildPgloaderScript("/tmp/test.sqlite", "postgresql://u:p@host/db", pgloaderScriptConfig{ + dataOnly: true, + tableName: "external_entities", + profile: pgloaderProfileForTable(0), + }, []TableSchema{*externalEntities}, tables, ctx) + if !strings.Contains(script, "column external_entities.id to uuid using sqlite-text-to-uuid") { + t.Fatalf("script missing external_entities UUID cast\n%s", script) + } +} + +func TestBuildPgloaderScriptFullLoadResetsSequences(t *testing.T) { + script := buildPgloaderScript("/tmp/test.sqlite", "postgresql://u:p@host/db", pgloaderScriptConfig{ + dataOnly: true, + resetSequences: true, + profile: pgloaderProfileForTable(0), + }, nil, nil, nil) + if !strings.Contains(script, "reset sequences,") { + t.Fatalf("expected reset sequences in final table script:\n%s", script) + } + if strings.Contains(script, "INCLUDING ONLY") { + t.Fatalf("did not expect table filter for full load:\n%s", script) + } +} + +func TestPgloaderLoadTablesSkipsORMMetadata(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "dump.sql") + if err := os.WriteFile(path, []byte(` +CREATE TABLE organizations (id INTEGER PRIMARY KEY); +CREATE TABLE __drizzle_migrations (id INTEGER PRIMARY KEY); +CREATE TABLE users (id INTEGER PRIMARY KEY, org_id INTEGER); +`), 0o600); err != nil { + t.Fatal(err) + } + + tables, err := PgloaderLoadTables(path) + if err != nil { + t.Fatalf("PgloaderLoadTables: %v", err) + } + if len(tables) != 2 { + t.Fatalf("tables = %v, want [organizations users]", tables) + } + if tables[0] != "organizations" || tables[1] 
!= "users" { + t.Fatalf("load order = %v", tables) + } +} + +func TestPgloaderTableNameFilterExactMatch(t *testing.T) { + got := pgloaderTableNameFilter("entity_links", nil) + want := ` LIKE 'entity_links'` + if got != want { + t.Fatalf("pgloaderTableNameFilter() = %q, want %q", got, want) + } + got = pgloaderTableNameFilter("100%done", nil) + if got != ` LIKE '100%done'` { + t.Fatalf("pgloaderTableNameFilter() = %q", got) + } + all := []string{"tbl_a", "tbl1a", "users"} + got = pgloaderTableNameFilter("tbl_a", all) + if !strings.Contains(got, ` LIKE 'tbl_a'`) { + t.Fatalf("pgloaderTableNameFilter() = %q", got) + } + if !strings.Contains(got, `EXCLUDING TABLE NAMES LIKE 'tbl1a'`) { + t.Fatalf("expected false-positive exclusion, got %q", got) + } + if strings.Contains(got, `EXCLUDING TABLE NAMES LIKE 'users'`) { + t.Fatalf("did not expect users excluded, got %q", got) + } + got = pgloaderTableNameFilter("O'Brien", nil) + if got != ` LIKE 'O''Brien'` { + t.Fatalf("pgloaderTableNameFilter() = %q", got) + } +} + +func TestSQLLikeMatch(t *testing.T) { + if !sqlLikeMatch("tbl_a", "tbl1a") { + t.Fatal("expected tbl_a pattern to match tbl1a") + } + if sqlLikeMatch("tbl_a", "users") { + t.Fatal("expected tbl_a pattern not to match users") + } + if sqlLikeMatch("user_data", "users_data") { + t.Fatal("expected user_data pattern not to match users_data") + } +} + +func TestPgloaderHadErrors(t *testing.T) { + tests := []struct { + name string + output string + want bool + }{ + { + name: "clean summary", + output: ` +| errors | rows | bytes | total time +| 0 | 100 | 1 kB | 1.000 s +`, + want: false, + }, + { + name: "summary with errors", + output: ` +| errors | rows | bytes | total time +| 3 | 97 | 1 kB | 1.000 s +`, + want: true, + }, + { + name: "database error", + output: "Database error 42501: must be owner of table users", + want: true, + }, + { + name: "insufficient privilege", + output: "INSUFFICIENT-PRIVILEGE disable triggers", + want: true, + }, + } + + for _, tt := 
range tests { + t.Run(tt.name, func(t *testing.T) { + if got := pgloaderHadErrors(tt.output); got != tt.want { + t.Fatalf("pgloaderHadErrors() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPgloaderFetchMetaDataTableCount(t *testing.T) { + if got := pgloaderFetchMetaDataTableCount(pgloaderOrganizationsOK); got != 1 { + t.Fatalf("organizations meta = %d, want 1", got) + } + if got := pgloaderFetchMetaDataTableCount(pgloaderTeamMembersZero); got != 0 { + t.Fatalf("team_members meta = %d, want 0", got) + } + if got := pgloaderFetchMetaDataTableCount("no summary"); got != -1 { + t.Fatalf("missing meta = %d, want -1", got) + } +} + +func TestPgloaderRowsCopied(t *testing.T) { + rows, ok := pgloaderRowsCopied(pgloaderOrganizationsOK, "organizations") + if !ok || rows != 28 { + t.Fatalf("organizations rows = (%d, %v), want (28, true)", rows, ok) + } + rows, ok = pgloaderRowsCopied(pgloaderTeamMembersZero, "team_members") + if ok || rows != 0 { + t.Fatalf("team_members rows = (%d, %v), want (0, false)", rows, ok) + } +} + +func TestValidatePgloaderTableLoad(t *testing.T) { + if err := validatePgloaderTableLoad(pgloaderOrganizationsOK, "organizations", 28); err != nil { + t.Fatalf("expected ok load: %v", err) + } + if err := validatePgloaderTableLoad(pgloaderOrganizationsOK, "organizations", 0); err == nil { + t.Fatal("expected error when staged SQLite row count does not match pgloader output") + } + if err := validatePgloaderTableLoad(pgloaderOrganizationsOK, "organizations", 30); err == nil { + t.Fatal("expected error for row count mismatch") + } + if err := validatePgloaderTableLoad(pgloaderTeamMembersZero, "team_members", 700); err == nil { + t.Fatal("expected error for 0-row load") + } else if me, ok := err.(*MigrationError); !ok || me.Info.Code != ErrCodeImportFailed { + t.Fatalf("error = %#v", err) + } + if err := validatePgloaderTableLoad(pgloaderTeamMembersZero, "team_members", 0); err == nil { + t.Fatal("expected error when pgloader matched 0 source 
tables") + } +} + +func TestConvertSchemaPartsSplitsIndexes(t *testing.T) { + parts, count, err := ConvertSchemaParts(testFixture(t)) + if err != nil { + t.Fatalf("ConvertSchemaParts: %v", err) + } + if count != 4 { + t.Fatalf("expected 4 tables, got %d", count) + } + if !strings.Contains(parts.Tables, `CREATE TABLE IF NOT EXISTS "users"`) { + t.Fatalf("expected users table DDL") + } + if strings.Contains(parts.Tables, "CREATE INDEX") { + t.Fatalf("tables section should not contain indexes") + } + if !strings.Contains(parts.Indexes, `CREATE INDEX IF NOT EXISTS "idx_users_email"`) { + t.Fatalf("expected index DDL in indexes section:\n%s", parts.Indexes) + } +} diff --git a/internal/import/d1/pgloader_transforms.lisp b/internal/import/d1/pgloader_transforms.lisp new file mode 100644 index 000000000..9be1c03d2 --- /dev/null +++ b/internal/import/d1/pgloader_transforms.lisp @@ -0,0 +1,32 @@ +(in-package #:pgloader.transforms) + +(defun sqlite-int-to-boolean (val) + "SQLite stores booleans as INTEGER 0/1; PostgreSQL COPY expects boolean." + (cond + ((null val) :null) + ((and (integerp val) (zerop val)) "false") + ((and (integerp val) (= val 1)) "true") + ((and (stringp val) (string= val "0")) "false") + ((and (stringp val) (string= val "1")) "true") + (t :null))) + +(defun sqlite-text-to-jsonb (val) + "SQLite JSON lives in TEXT; pass valid JSON through to PostgreSQL JSONB." + (cond + ((null val) :null) + ((stringp val) val) + (t (format nil "~a" val)))) + +(defun sqlite-timestamp-to-timestamp (val) + "SQLite timestamps in TEXT/DATETIME columns; pass through for PostgreSQL TIMESTAMPTZ." + (cond + ((null val) :null) + ((stringp val) val) + (t (format nil "~a" val)))) + +(defun sqlite-text-to-uuid (val) + "SQLite UUID keys live in TEXT; pass through for PostgreSQL UUID." 
// PlanOptions configures migration planning.
//
// Plan reads the dump at InputPath and copies Org, Database, and Branch onto
// the resulting PlanResult; the remaining fields override values Plan would
// otherwise compute itself.
type PlanOptions struct {
	// InputPath is the SQLite dump that Plan parses, lints, row-counts, and sizes.
	InputPath string
	// Org, Database, and Branch are copied unchanged onto the PlanResult.
	Org      string
	Database string
	Branch   string
	// Method selects the import method. When empty, Plan picks one from the
	// dump's file size via recommendMethod.
	Method string
	// MigrationID, when empty, is replaced by a freshly generated
	// 12-character lowercase alphanumeric ID.
	MigrationID string      // optional: reuse an existing migration ID from plan/start
	Lint        *LintResult // optional: skip re-lint when already computed
}
// parseFKReference extracts the referenced table name from a column-level
// foreign key clause such as `REFERENCES users(id)`.
//
// The REFERENCES keyword is matched case-insensitively. The table name may be
// bare or wrapped in backticks, double quotes, or single quotes, and may be
// followed directly by a parenthesised column list (`"users"("id")`). It
// returns "" when fk has no REFERENCES keyword or nothing follows it.
func parseFKReference(fk string) string {
	const keyword = "REFERENCES"

	// Compare byte windows of fk itself rather than searching an upper-cased
	// copy: case mapping can change the byte length of non-ASCII text, which
	// would make an offset found in the copy wrong for the original string.
	start := -1
	for i := 0; i+len(keyword) <= len(fk); i++ {
		if strings.EqualFold(fk[i:i+len(keyword)], keyword) {
			start = i + len(keyword)
			break
		}
	}
	if start < 0 {
		return ""
	}

	rest := strings.TrimSpace(fk[start:])
	if rest == "" {
		return ""
	}

	// Quoted identifier: take everything up to the matching closing quote, so
	// a column list glued to the name cannot leave a stray quote behind and a
	// quoted name containing spaces is kept whole.
	switch quote := rest[0]; quote {
	case '`', '"', '\'':
		if end := strings.IndexByte(rest[1:], quote); end >= 0 {
			return rest[1 : 1+end]
		}
	}

	// Bare (or unterminated-quote) identifier: use the first whitespace-
	// delimited token, and drop the column list BEFORE trimming quote
	// characters so the trim sees the end of the name, not a ')'.
	fields := strings.Fields(rest)
	if len(fields) == 0 {
		return ""
	}
	ref := fields[0]
	if paren := strings.IndexByte(ref, '('); paren >= 0 {
		ref = ref[:paren]
	}
	return strings.Trim(ref, "`\"'")
}
+func StartNextSteps(migrationID, database, branch, method, inputPath string, dryRun bool) []NextStep { + target := CLICommandTarget(database, branch) + if dryRun { + cmd := fmt.Sprintf("pscale import d1 start %s --migration-id %s", target, migrationID) + if inputPath != "" { + cmd += fmt.Sprintf(" --input %q", inputPath) + } + if method != "" { + cmd += fmt.Sprintf(" --method %s", method) + } + return []NextStep{ + { + Command: cmd, + Reason: "Run the import after preview", + }, + } + } + verifyCmd := fmt.Sprintf("pscale import d1 verify %s --migration-id %s", target, migrationID) + if inputPath != "" { + verifyCmd += fmt.Sprintf(" --input %q", inputPath) + } + return []NextStep{ + { + Command: verifyCmd, + Reason: "Verify row counts, sequences, and content after import", + }, + } +} + +// StatusNextSteps returns the recommended next command for the current migration phase. +func StatusNextSteps(state *MigrationState) []NextStep { + if state == nil { + return nil + } + + target := CLICommandTarget(state.Database, state.Branch) + migrationID := state.MigrationID + + switch state.Phase { + case PhasePlanned: + cmd := fmt.Sprintf("pscale import d1 start %s --migration-id %s", target, migrationID) + if state.InputPath != "" { + cmd += fmt.Sprintf(" --input %q", state.InputPath) + } + if state.Method != "" { + cmd += fmt.Sprintf(" --method %s", state.Method) + } + return []NextStep{{ + Command: cmd, + Reason: "Run the import after dry-run preview", + }} + case PhaseImporting: + return []NextStep{{ + Command: fmt.Sprintf("pscale import d1 status %s --migration-id %s", target, migrationID), + Reason: "Import in progress; check status again when it finishes", + }} + case PhaseImported: + cmd := fmt.Sprintf("pscale import d1 verify %s --migration-id %s", target, migrationID) + if state.InputPath != "" { + cmd += fmt.Sprintf(" --input %q", state.InputPath) + } + return []NextStep{{ + Command: cmd, + Reason: "Verify row counts and content after import", + }} + case 
// CLICommandTarget formats database and branch for pscale import d1 command examples.
//
// An empty branch or the branch "main" is omitted, so only the database name
// is returned; any other branch is appended after a single space.
func CLICommandTarget(database, branch string) string {
	switch branch {
	case "", "main":
		return database
	default:
		return database + " " + branch
	}
}
t.Fatalf("command = %q, want %q", steps[0].Command, want) + } +} + +func TestStartNextStepsDryRunOmitsForce(t *testing.T) { + steps := StartNextSteps("abc123", "mydb", "main", "pgloader", "./d1-export.sql", true) + if len(steps) != 1 { + t.Fatalf("steps = %d, want 1", len(steps)) + } + want := `pscale import d1 start mydb --migration-id abc123 --input "./d1-export.sql" --method pgloader` + if steps[0].Command != want { + t.Fatalf("command = %q, want %q", steps[0].Command, want) + } +} + +func TestStatusNextStepsImported(t *testing.T) { + steps := StatusNextSteps(&MigrationState{ + MigrationID: "abc123", + Database: "import-9gb", + Branch: "main", + InputPath: "/tmp/export.sql", + Phase: PhaseImported, + }) + if len(steps) != 1 { + t.Fatalf("steps = %d, want 1", len(steps)) + } + want := `pscale import d1 verify import-9gb --migration-id abc123 --input "/tmp/export.sql"` + if steps[0].Command != want { + t.Fatalf("command = %q, want %q", steps[0].Command, want) + } +} + +func TestStatusNextStepsCompleteHasNoSteps(t *testing.T) { + steps := StatusNextSteps(&MigrationState{ + MigrationID: "abc123", + Database: "import-9gb", + Branch: "main", + Phase: PhaseComplete, + }) + if len(steps) != 0 { + t.Fatalf("steps = %d, want 0", len(steps)) + } +} diff --git a/internal/import/d1/postgres.go b/internal/import/d1/postgres.go new file mode 100644 index 000000000..6edc08bbf --- /dev/null +++ b/internal/import/d1/postgres.go @@ -0,0 +1,12 @@ +package d1 + +import ( + "database/sql" + + "github.com/planetscale/cli/internal/postgres" +) + +// OpenPostgres opens a PostgreSQL connection. 
+func OpenPostgres(uri string) (*sql.DB, error) { + return postgres.OpenConnection(uri) +} diff --git a/internal/import/d1/prepare.go b/internal/import/d1/prepare.go new file mode 100644 index 000000000..dcdb96fc1 --- /dev/null +++ b/internal/import/d1/prepare.go @@ -0,0 +1,196 @@ +package d1 + +import ( + "github.com/planetscale/cli/internal/printer" +) + +// ImportPrepareResult is lint + plan output used before and during import. +type ImportPrepareResult struct { + MigrationID string `json:"migration_id"` + Method string `json:"method"` + Lint *LintResult `json:"lint"` + Plan *PlanResult `json:"plan"` + CanProceed bool `json:"can_proceed"` + BlockedReason string `json:"blocked_reason,omitempty"` +} + +// PrepareImport runs lint and resolves or creates a migration plan without touching Postgres. +func PrepareImport(opts ImportOptions) (*ImportPrepareResult, error) { + inputPath, err := NormalizeInputPath(opts.InputPath) + if err != nil { + return nil, err + } + opts.InputPath = inputPath + + if opts.MigrationID != "" { + if _, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID); err != nil { + return nil, err + } + } + + lintResult, err := Lint(opts.InputPath) + if err != nil { + return nil, err + } + + method := opts.Method + if method == "" { + size, err := FileSize(opts.InputPath) + if err != nil { + return nil, err + } + method = recommendMethod(size) + } + + plan, err := resolvePlan(opts, method, lintResult) + if err != nil { + return nil, err + } + + if opts.Method != "" { + plan.RecommendedMethod = opts.Method + } + method = plan.RecommendedMethod + + out := &ImportPrepareResult{ + MigrationID: plan.MigrationID, + Method: method, + Lint: lintResult, + Plan: plan, + CanProceed: lintResult.ErrorCount == 0, + } + if !out.CanProceed { + out.BlockedReason = lintBlockedReason(lintResult.ErrorCount) + } + return out, nil +} + +func resolvePlan(opts ImportOptions, method string, lint *LintResult) (*PlanResult, error) { + if opts.MigrationID == 
"" { + return createAndSavePlan(PlanOptions{ + InputPath: opts.InputPath, + Org: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + Method: method, + Lint: lint, + }) + } + + state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) + if err != nil { + return nil, err + } + + if opts.InputPath != "" && state.InputPath != "" { + if err := validateInputPathAgainstState(opts.InputPath, state.InputPath); err != nil { + return nil, err + } + } + + inputPath := opts.InputPath + if inputPath == "" { + inputPath = state.InputPath + } + + plan, err := Plan(PlanOptions{ + InputPath: inputPath, + Org: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + Method: method, + MigrationID: state.MigrationID, + Lint: lint, + }) + if err != nil { + return nil, err + } + if state.Method != "" { + plan.RecommendedMethod = state.Method + } + return plan, nil +} + +func createAndSavePlan(opts PlanOptions) (*PlanResult, error) { + plan, err := Plan(opts) + if err != nil { + return nil, err + } + if err := SavePlan(plan); err != nil { + return nil, err + } + return plan, nil +} + +func importResultFromPrepare(prepared *ImportPrepareResult, dryRun bool) *ImportResult { + return &ImportResult{ + MigrationID: prepared.MigrationID, + Method: prepared.Method, + DryRun: dryRun, + Lint: prepared.Lint, + Plan: prepared.Plan, + CanProceed: prepared.CanProceed, + } +} + +// BlockedStartResponse builds the start error envelope when lint blocks import. 
+func BlockedStartResponse(prepared *ImportPrepareResult, dryRun bool) Response { + resp := ErrorResponse("start", ErrLintBlocked(prepared.BlockedReason)) + if prepared.Lint != nil { + resp.Issues = prepared.Lint.Issues + } + resp.Data = ImportResult{ + MigrationID: prepared.MigrationID, + Method: prepared.Method, + DryRun: dryRun, + Lint: prepared.Lint, + Plan: prepared.Plan, + CanProceed: false, + } + resp.MigrationID = prepared.MigrationID + return resp +} + +// PrintStartPreview writes a human-readable lint/plan summary before import confirmation. +func PrintStartPreview(p *printer.Printer, prepared *ImportPrepareResult) { + if prepared == nil { + return + } + p.Println("\nImport preview") + if prepared.Lint != nil { + p.Printf(" Lint: %d error(s), %d warning(s)\n", prepared.Lint.ErrorCount, prepared.Lint.WarningCount) + for _, issue := range prepared.Lint.Issues { + if issue.Severity != SeverityError && issue.Severity != SeverityWarning { + continue + } + loc := issue.Table + if issue.Column != "" { + loc += "." 
+ issue.Column + } + if loc != "" { + loc = " " + loc + } + p.Printf(" [%s] %s%s: %s\n", issue.Severity, issue.Code, loc, previewMessage(issue)) + } + } + if prepared.Plan != nil { + sizeMB := float64(prepared.Plan.EstimatedSizeBytes) / (1024 * 1024) + p.Printf(" Plan: migration_id %s, method %s, %.1f MB, %d tables\n", + prepared.Plan.MigrationID, + prepared.Plan.RecommendedMethod, + sizeMB, + len(prepared.Plan.Tables), + ) + } + if prepared.BlockedReason != "" { + p.Printf(" Blocked: %s\n", prepared.BlockedReason) + } + p.Println() +} + +func previewMessage(issue Issue) string { + if issue.Message != "" { + return issue.Message + } + return issue.Remediation +} diff --git a/internal/import/d1/prepare_test.go b/internal/import/d1/prepare_test.go new file mode 100644 index 000000000..e0d447f13 --- /dev/null +++ b/internal/import/d1/prepare_test.go @@ -0,0 +1,167 @@ +package d1 + +import ( + "bytes" + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/planetscale/cli/internal/printer" +) + +func TestPrepareImport(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + prepared, err := PrepareImport(ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + Branch: "main", + }) + if err != nil { + t.Fatalf("PrepareImport: %v", err) + } + if !prepared.CanProceed { + t.Fatalf("expected can proceed, blocked: %s", prepared.BlockedReason) + } + if prepared.MigrationID == "" { + t.Fatal("expected migration id") + } + if prepared.Lint == nil || prepared.Plan == nil { + t.Fatal("expected lint and plan in prepare result") + } + if prepared.Method != prepared.Plan.RecommendedMethod { + t.Fatalf("method mismatch: %q vs %q", prepared.Method, prepared.Plan.RecommendedMethod) + } +} + +func TestImport_BlocksOnLintErrors(t *testing.T) { + prepared := &ImportPrepareResult{ + MigrationID: "mig-test", + Method: MethodPgloader, + CanProceed: false, + BlockedReason: "lint reported 1 error(s); fix or use import d1 lint for details", + Lint: 
&LintResult{ + ErrorCount: 1, + Issues: []Issue{{ + Code: "TEST", + Severity: SeverityError, + Message: "blocked for test", + }}, + }, + } + + result, err := Import(context.Background(), nil, nil, ImportOptions{DryRun: true}, prepared) + if err == nil { + t.Fatal("expected lint blocked error") + } + requireMigrationErr(t, err, ErrCodeLintBlocked) + if result == nil || result.CanProceed { + t.Fatal("expected result with can_proceed false") + } +} + +func TestPrepareImportRejectsMissingMigrationState(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + _, err := PrepareImport(ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + Branch: "main", + MigrationID: "missing-migration-id", + }) + requireMigrationErr(t, err, ErrCodeNotFound) +} + +func TestPrepareImportAcceptsEquivalentInputPath(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + org, database, branch := "acme", "mydb", "main" + absFixture := testFixture(t) + + prepared, err := PrepareImport(ImportOptions{ + InputPath: absFixture, + Org: org, + Database: database, + Branch: branch, + }) + if err != nil { + t.Fatalf("initial PrepareImport: %v", err) + } + + dir := filepath.Dir(absFixture) + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd: %v", err) + } + if err := os.Chdir(dir); err != nil { + t.Fatalf("Chdir: %v", err) + } + t.Cleanup(func() { _ = os.Chdir(origDir) }) + + again, err := PrepareImport(ImportOptions{ + InputPath: "./" + filepath.Base(absFixture), + Org: org, + Database: database, + Branch: branch, + MigrationID: prepared.MigrationID, + }) + if err != nil { + t.Fatalf("PrepareImport with relative path: %v", err) + } + if again.MigrationID != prepared.MigrationID { + t.Fatalf("migration = %q, want %q", again.MigrationID, prepared.MigrationID) + } +} + +func TestPrepareImportRejectsCorruptMigrationState(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + store, err := NewStateStore() + if err != nil { + t.Fatal(err) + } + migrationID := 
"corrupt-migration-id" + path := store.statePath("acme", "mydb", "main", migrationID) + if err := os.WriteFile(path, []byte("{not-json"), 0o600); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = os.Remove(path) }) + + _, err = PrepareImport(ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + Branch: "main", + MigrationID: migrationID, + }) + if err == nil { + t.Fatal("expected corrupt migration state to fail") + } +} + +func TestPrintStartPreview(t *testing.T) { + prepared, err := PrepareImport(ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + }) + if err != nil { + t.Fatalf("PrepareImport: %v", err) + } + + var buf bytes.Buffer + format := printer.Human + p := printer.NewPrinter(&format) + p.SetHumanOutput(&buf) + PrintStartPreview(p, prepared) + out := buf.String() + for _, want := range []string{"Import preview", "Lint:", "Plan:", prepared.MigrationID} { + if !strings.Contains(out, want) { + t.Fatalf("preview missing %q:\n%s", want, out) + } + } +} diff --git a/internal/import/d1/progress.go b/internal/import/d1/progress.go new file mode 100644 index 000000000..0b4d793b6 --- /dev/null +++ b/internal/import/d1/progress.go @@ -0,0 +1,93 @@ +package d1 + +import "fmt" + +// Import stage names for progress reporting. +const ( + ImportStageConnecting = "connecting" + ImportStageSQLiteStaging = "sqlite_staging" + ImportStageSchema = "schema" + ImportStagePgloader = "pgloader" + ImportStageIndexes = "indexes" + ImportStageSequences = "sequences" +) + +// Verify stage names for progress reporting. +const ( + VerifyStageRowCounts = "row_counts" + VerifyStageSequences = "verify_sequences" + VerifyStageBoolean = "boolean_columns" + VerifyStageFingerprints = "fingerprints" + VerifyStageSampleRows = "sample_rows" +) + +// ImportProgress describes import or verify pipeline progress for CLI and agent feedback. 
+type ImportProgress struct { + Stage string `json:"stage"` + Current int `json:"current,omitempty"` + Total int `json:"total,omitempty"` + Detail string `json:"detail,omitempty"` +} + +// ImportProgressFunc receives progress updates during import or verify. +type ImportProgressFunc func(ImportProgress) + +// FormatProgressMessage returns a human-readable progress line for CLI and Slack. +func FormatProgressMessage(p ImportProgress) string { + switch p.Stage { + case ImportStageConnecting: + return "Connecting to PlanetScale Postgres..." + case ImportStageSQLiteStaging: + return "Staging SQLite database from export..." + case ImportStageSchema: + return "Applying PostgreSQL schema..." + case ImportStagePgloader: + if p.Total > 0 && p.Detail != "" { + return fmt.Sprintf("Loading table %d/%d: %s", p.Current, p.Total, p.Detail) + } + if p.Detail != "" { + return fmt.Sprintf("Loading table %s", p.Detail) + } + return "Loading tables with pgloader..." + case ImportStageIndexes: + return "Building indexes..." + case ImportStageSequences: + return "Resetting identity sequences..." + case VerifyStageRowCounts: + if p.Total > 0 && p.Detail != "" { + return fmt.Sprintf("Counting rows %d/%d: %s", p.Current, p.Total, p.Detail) + } + return "Comparing row counts..." + case VerifyStageSequences: + return "Checking identity sequences..." + case VerifyStageBoolean: + return "Checking boolean column coercion..." + case VerifyStageFingerprints: + return "Checking table fingerprints..." + case VerifyStageSampleRows: + return "Sampling row content..." + default: + if p.Detail != "" { + return p.Detail + } + return "Working..." 
+ } +} + +func (opts ImportOptions) reportProgress(p ImportProgress) { + if opts.OnProgress != nil { + opts.OnProgress(p) + } + if opts.MigrationID != "" { + notifyImportProgress(opts.NotifyAPI, opts.Org, opts.Database, opts.Branch, opts.MigrationID, opts.notifyBase, p) + } +} + +func (opts VerifyOptions) reportProgress(p ImportProgress) { + if opts.OnProgress != nil { + opts.OnProgress(p) + } + if opts.MigrationID != "" { + notifyImportProgress(opts.NotifyAPI, opts.Org, opts.Database, opts.Branch, opts.MigrationID, opts.notifyBase, p) + } +} diff --git a/internal/import/d1/schema_reset.go b/internal/import/d1/schema_reset.go new file mode 100644 index 000000000..00b8be819 --- /dev/null +++ b/internal/import/d1/schema_reset.go @@ -0,0 +1,242 @@ +package d1 + +import ( + "context" + "fmt" + "strings" + + ps "github.com/planetscale/planetscale-go/planetscale" + + "github.com/planetscale/cli/internal/postgres" +) + +const ( + postgresRoleName = "postgres" + publicSchemaName = "public" +) + +func cleanupStaleImportRoles(ctx context.Context, psClient *ps.Client, opts ImportOptions, currentUsername string) error { + if psClient == nil || opts.DestURI != "" { + return nil + } + + roles, err := psClient.PostgresRoles.List(ctx, &ps.ListPostgresRolesRequest{ + Organization: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + }) + if err != nil { + return fmt.Errorf("list postgres roles: %w", err) + } + + if err := ensureDefaultPostgresRole(ctx, psClient, opts, roles); err != nil { + return err + } + + var firstErr error + for _, role := range roles { + if !isStaleImportRole(role, currentUsername) { + continue + } + err := psClient.PostgresRoles.Delete(ctx, &ps.DeletePostgresRoleRequest{ + Organization: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + RoleId: role.ID, + Successor: postgresRoleName, + }) + if err != nil && firstErr == nil { + firstErr = fmt.Errorf("delete stale import role %q: %w", role.Username, err) + } + } + return firstErr +} + +func 
ensureDefaultPostgresRole(ctx context.Context, psClient *ps.Client, opts ImportOptions, roles []*ps.PostgresRole) error { + for _, role := range roles { + if role != nil && isDefaultPostgresRole(role.Username) { + return nil + } + } + + _, err := psClient.PostgresRoles.ResetDefaultRole(ctx, &ps.ResetDefaultRoleRequest{ + Organization: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + }) + if err != nil { + return fmt.Errorf("ensure default postgres role: %w", err) + } + return nil +} + +func isDefaultPostgresRole(username string) bool { + return username == postgresRoleName || strings.HasPrefix(username, postgresRoleName+".") +} + +const importRoleNamePrefix = "d1-import-" + +func isImportRoleName(name string) bool { + return strings.HasPrefix(name, importRoleNamePrefix) +} + +func isStaleImportRole(role *ps.PostgresRole, currentUsername string) bool { + if role == nil || role.Username == currentUsername { + return false + } + // Only delete roles created by this import flow (d1-import-*). Do not touch + // other ephemeral API roles (shell, manual admin roles, concurrent work). 
+ return isImportRoleName(role.Name) +} + +func usernameFromDestURI(destURI string) (string, error) { + cfg, err := postgres.ParseConnectionURI(destURI) + if err != nil { + return "", err + } + if cfg.User == "" { + return "", fmt.Errorf("destination URI missing user") + } + return cfg.User, nil +} + +func importTableNames(tables []TableSchema) []string { + names := make([]string, 0, len(tables)) + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + names = append(names, table.Name) + } + return names +} + +func existingPublicTables(ctx context.Context, destURI string, names []string) (map[string]struct{}, error) { + existing := make(map[string]struct{}) + if len(names) == 0 { + return existing, nil + } + + db, err := OpenPostgres(destURI) + if err != nil { + return nil, err + } + defer db.Close() + + placeholders := make([]string, len(names)) + args := make([]any, len(names)) + for i, name := range names { + placeholders[i] = fmt.Sprintf("$%d", i+1) + args[i] = name + } + query := fmt.Sprintf( + `SELECT table_name FROM information_schema.tables WHERE table_schema = '%s' AND table_name IN (%s)`, + publicSchemaName, + strings.Join(placeholders, ", "), + ) + + rows, err := db.QueryContext(ctx, query, args...) 
// conflictingImportTables returns the entries of importNames that are also
// keys of existing, preserving importNames order.
//
// It returns nil when existing is empty; otherwise the result is a non-nil
// slice, which may have length zero.
func conflictingImportTables(importNames []string, existing map[string]struct{}) []string {
	if len(existing) == 0 {
		return nil
	}

	clashes := make([]string, 0, len(importNames))
	for i := range importNames {
		candidate := importNames[i]
		_, taken := existing[candidate]
		if !taken {
			continue
		}
		clashes = append(clashes, candidate)
	}
	return clashes
}
buildImportTablesSQL(inputPath string, tables []TableSchema) (string, error) { + var coerceCtx *TypeCoercionContext + if inputPath != "" { + var err error + coerceCtx, err = BuildTypeCoercionContext(inputPath, tables) + if err != nil { + return "", err + } + } + + tableByName := make(map[string]TableSchema, len(tables)) + for _, table := range tables { + tableByName[table.Name] = table + } + + var b strings.Builder + for _, name := range topologicalLoadOrder(tables) { + table, ok := tableByName[name] + if !ok || IsORMMetadataTable(table.Name) { + continue + } + b.WriteString(convertTableDDL(table, tables, coerceCtx)) + b.WriteString("\n\n") + } + return b.String(), nil +} diff --git a/internal/import/d1/schema_reset_test.go b/internal/import/d1/schema_reset_test.go new file mode 100644 index 000000000..7522e40d3 --- /dev/null +++ b/internal/import/d1/schema_reset_test.go @@ -0,0 +1,66 @@ +package d1 + +import ( + "strings" + "testing" +) + +func TestConflictingImportTables(t *testing.T) { + existing := map[string]struct{}{ + "organizations": {}, + "posts": {}, + "other_app": {}, + } + conflicts := conflictingImportTables([]string{"organizations", "users", "posts"}, existing) + if len(conflicts) != 2 || conflicts[0] != "organizations" || conflicts[1] != "posts" { + t.Fatalf("conflicts = %v", conflicts) + } +} + +func TestErrExistingImportTables(t *testing.T) { + err := errExistingImportTables([]string{"users", "posts"}) + requireMigrationErr(t, err, ErrCodeDestinationConflict) + me, _ := migrationErr(err) + if !strings.Contains(me.Info.Message, "users, posts") { + t.Fatalf("message = %q", me.Info.Message) + } +} + +func TestBuildImportTablesSQLCreatesAllImportTables(t *testing.T) { + tables := []TableSchema{ + { + Name: "organizations", + Columns: []ColumnSchema{ + {Name: "id", Type: "INTEGER", PrimaryKey: true, AutoIncrement: true}, + }, + }, + { + Name: "users", + Columns: []ColumnSchema{ + {Name: "id", Type: "INTEGER", PrimaryKey: true, AutoIncrement: true}, + }, 
+ }, + } + + sql, err := buildImportTablesSQL("", tables) + if err != nil { + t.Fatalf("buildImportTablesSQL: %v", err) + } + if !strings.Contains(sql, `CREATE TABLE IF NOT EXISTS "organizations"`) { + t.Fatalf("expected organizations table DDL:\n%s", sql) + } + if !strings.Contains(sql, `CREATE TABLE IF NOT EXISTS "users"`) { + t.Fatalf("expected users table DDL:\n%s", sql) + } +} + +func TestImportTableNamesSkipsORMMetadata(t *testing.T) { + names := importTableNames([]TableSchema{ + {Name: "users"}, + {Name: "__drizzle_migrations"}, + {Name: "posts"}, + }) + if len(names) != 2 || names[0] != "users" || names[1] != "posts" { + t.Fatalf("names = %v", names) + } +} diff --git a/internal/import/d1/sqlite_load.go b/internal/import/d1/sqlite_load.go new file mode 100644 index 000000000..0c451bd2a --- /dev/null +++ b/internal/import/d1/sqlite_load.go @@ -0,0 +1,259 @@ +package d1 + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "time" + + execabs "golang.org/x/sys/execabs" +) + +const defaultSQLiteChunkBytes = 64 << 20 // 64 MiB of SQL per .read batch + +type sqliteSourceMeta struct { + DumpSize int64 `json:"dump_size"` + DumpModTime time.Time `json:"dump_mod_time"` +} + +// EnsureSQLiteFromDump loads dump SQL into sqlite unless a fresh-enough database already exists. +func EnsureSQLiteFromDump(ctx context.Context, dumpPath, sqlitePath string) error { + if canReuseSQLite(dumpPath, sqlitePath) { + return nil + } + return buildSQLiteFromDump(ctx, dumpPath, sqlitePath) +} + +// BuildSQLiteFromDump always rebuilds sqlite from the dump (tests and forced refresh). 
+func BuildSQLiteFromDump(ctx context.Context, dumpPath, sqlitePath string) error { + return buildSQLiteFromDump(ctx, dumpPath, sqlitePath) +} + +func buildSQLiteFromDump(ctx context.Context, dumpPath, sqlitePath string) error { + dumpPath, err := ValidateInputPath(dumpPath) + if err != nil { + return err + } + + sqlite3, err := FindSQLite3() + if err != nil { + return err + } + + if err := os.RemoveAll(sqlitePath); err != nil && !os.IsNotExist(err) { + return err + } + _ = os.Remove(sqliteSourceMetaPath(sqlitePath)) + + dir := filepath.Dir(sqlitePath) + if err := os.MkdirAll(dir, 0o700); err != nil { + return err + } + + if err := loadSQLiteDumpChunked(ctx, sqlite3, dumpPath, sqlitePath, defaultSQLiteChunkBytes); err != nil { + return err + } + return writeSQLiteSourceMeta(dumpPath, sqlitePath) +} + +func sqliteSourceMetaPath(sqlitePath string) string { + return sqlitePath + ".source" +} + +func writeSQLiteSourceMeta(dumpPath, sqlitePath string) error { + dumpInfo, err := os.Stat(dumpPath) + if err != nil { + return err + } + meta := sqliteSourceMeta{ + DumpSize: dumpInfo.Size(), + DumpModTime: dumpInfo.ModTime(), + } + data, err := json.Marshal(meta) + if err != nil { + return err + } + return os.WriteFile(sqliteSourceMetaPath(sqlitePath), data, 0o600) +} + +func canReuseSQLite(dumpPath, sqlitePath string) bool { + dumpInfo, err := os.Stat(dumpPath) + if err != nil { + return false + } + sqliteInfo, err := os.Stat(sqlitePath) + if err != nil || sqliteInfo.Size() == 0 { + return false + } + if !sqliteHasTables(sqlitePath) { + return false + } + metaData, err := os.ReadFile(sqliteSourceMetaPath(sqlitePath)) + if err != nil { + return false + } + var meta sqliteSourceMeta + if err := json.Unmarshal(metaData, &meta); err != nil { + return false + } + if meta.DumpSize != dumpInfo.Size() { + return false + } + if !meta.DumpModTime.Equal(dumpInfo.ModTime()) { + return false + } + return true +} + +func sqliteHasTables(sqlitePath string) bool { + sqlite3, err := 
FindSQLite3() + if err != nil { + return false + } + out, err := execabs.Command(sqlite3, sqlitePath, "SELECT 1 FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' LIMIT 1;").CombinedOutput() + if err != nil { + return false + } + return strings.TrimSpace(string(out)) == "1" +} + +func loadSQLiteDumpChunked(ctx context.Context, sqlite3, dumpPath, sqlitePath string, chunkBytes int64) error { + dump, err := os.Open(dumpPath) + if err != nil { + return err + } + defer dump.Close() + + chunkDir, err := os.MkdirTemp("", "pscale-d1-sqlite-chunk-*") + if err != nil { + return err + } + defer os.RemoveAll(chunkDir) + + reader := bufio.NewReader(dump) + var ( + chunkIdx int + chunkFile *os.File + chunkPath string + chunkSize int64 + lineNo int + totalLines int + ) + + flushChunk := func() error { + if chunkFile == nil { + return nil + } + if err := chunkFile.Close(); err != nil { + return err + } + chunkFile = nil + + readPath := strings.ReplaceAll(chunkPath, "'", "''") + cmd := execabs.CommandContext(ctx, sqlite3, sqlitePath, fmt.Sprintf(".read %s", readPath)) + var stderr bytes.Buffer + cmd.Stdout = io.Discard + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf( + "sqlite3 chunk %d (through line %d): %w: %s", + chunkIdx, + lineNo, + err, + truncateLoadError(stderr.String(), 2048), + ) + } + return os.Remove(chunkPath) + } + + startChunk := func() error { + chunkIdx++ + chunkPath = filepath.Join(chunkDir, fmt.Sprintf("chunk-%04d.sql", chunkIdx)) + f, err := os.OpenFile(chunkPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + if err != nil { + return err + } + chunkFile = f + chunkSize = 0 + return nil + } + + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + lineNo++ + totalLines++ + if chunkFile == nil { + if err := startChunk(); err != nil { + return err + } + } + if _, werr := chunkFile.Write(line); werr != nil { + return werr + } + chunkSize += int64(len(line)) + if chunkSize >= chunkBytes && 
// truncateLoadError trims surrounding whitespace from msg and, when the result
// is longer than max bytes, shortens it and appends "...".
//
// The cut never lands inside a multi-byte UTF-8 sequence: it is moved back to
// the start of the rune that straddles the limit, so the returned prefix may
// be slightly shorter than max bytes. A negative max is treated as zero.
func truncateLoadError(msg string, max int) string {
	msg = strings.TrimSpace(msg)
	if max < 0 {
		max = 0
	}
	if len(msg) <= max {
		return msg
	}

	cut := max
	// UTF-8 continuation bytes look like 10xxxxxx; step back until the byte at
	// the cut position begins a rune, so slicing keeps the text valid.
	for cut > 0 && msg[cut]&0xC0 == 0x80 {
		cut--
	}
	return msg[:cut] + "..."
}
blob, similar to wrangler export lines + fmt.Fprintf(&b, "INSERT INTO attachments (id, payload) VALUES(1,X'%s');\n", hex) + b.WriteString("INSERT INTO attachments (id, payload) VALUES(2,NULL);\n") + + if err := os.WriteFile(dumpPath, []byte(b.String()), 0o600); err != nil { + t.Fatal(err) + } + + if err := BuildSQLiteFromDump(context.Background(), dumpPath, sqlitePath); err != nil { + t.Fatalf("BuildSQLiteFromDump: %v", err) + } + + counts, err := CountSQLiteRows(context.Background(), sqlitePath, []string{"attachments"}) + if err != nil { + t.Fatal(err) + } + if counts["attachments"] != 2 { + t.Fatalf("expected 2 rows, got %d", counts["attachments"]) + } +} + +func TestEnsureSQLiteFromDumpReusesExisting(t *testing.T) { + requireSQLite3(t) + + dir := t.TempDir() + dumpPath := filepath.Join(dir, "dump.sql") + sqlitePath := filepath.Join(dir, "load.sqlite") + + content := "PRAGMA defer_foreign_keys=TRUE;\nCREATE TABLE t (id INTEGER PRIMARY KEY);\nINSERT INTO t VALUES(1);\n" + if err := os.WriteFile(dumpPath, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + if err := BuildSQLiteFromDump(context.Background(), dumpPath, sqlitePath); err != nil { + t.Fatal(err) + } + info1, err := os.Stat(sqlitePath) + if err != nil { + t.Fatal(err) + } + + time.Sleep(10 * time.Millisecond) + if err := os.WriteFile(dumpPath, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + // Touch dump to be newer than sqlite — should rebuild. 
+ dumpInfo, _ := os.Stat(dumpPath) + if err := os.Chtimes(dumpPath, dumpInfo.ModTime().Add(time.Second), dumpInfo.ModTime().Add(time.Second)); err != nil { + t.Fatal(err) + } + + if err := EnsureSQLiteFromDump(context.Background(), dumpPath, sqlitePath); err != nil { + t.Fatal(err) + } + info2, err := os.Stat(sqlitePath) + if err != nil { + t.Fatal(err) + } + if !info2.ModTime().After(info1.ModTime()) { + t.Fatal("expected rebuild when dump is newer than sqlite") + } + + // Unchanged dump should reuse without rebuild (meta matches dump size + mtime). + if err := EnsureSQLiteFromDump(context.Background(), dumpPath, sqlitePath); err != nil { + t.Fatal(err) + } + info3, err := os.Stat(sqlitePath) + if err != nil { + t.Fatal(err) + } + if info3.ModTime().After(info2.ModTime()) { + t.Fatal("expected sqlite reuse without rebuild when dump is unchanged") + } +} + +func TestCanReuseSQLiteRejectsDumpSizeMismatch(t *testing.T) { + requireSQLite3(t) + + dir := t.TempDir() + dumpPath := filepath.Join(dir, "dump.sql") + sqlitePath := filepath.Join(dir, "load.sqlite") + + content := "PRAGMA defer_foreign_keys=TRUE;\nCREATE TABLE t (id INTEGER PRIMARY KEY);\nINSERT INTO t VALUES(1);\n" + if err := os.WriteFile(dumpPath, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + if err := BuildSQLiteFromDump(context.Background(), dumpPath, sqlitePath); err != nil { + t.Fatal(err) + } + if !canReuseSQLite(dumpPath, sqlitePath) { + t.Fatal("expected reuse immediately after build") + } + + // Same mtime but different size should not reuse (stale sidecar). 
+ if err := os.WriteFile(dumpPath, []byte(content+"INSERT INTO t VALUES(2);\n"), 0o600); err != nil { + t.Fatal(err) + } + dumpInfo, _ := os.Stat(dumpPath) + sqliteInfo, _ := os.Stat(sqlitePath) + if err := os.Chtimes(dumpPath, sqliteInfo.ModTime(), sqliteInfo.ModTime()); err != nil { + t.Fatal(err) + } + if canReuseSQLite(dumpPath, sqlitePath) { + t.Fatalf("expected no reuse when dump size changed (dump=%d meta=%d)", dumpInfo.Size(), len(content)) + } +} + +func TestLoadSQLiteDumpChunkedMultiLineCreate(t *testing.T) { + requireSQLite3(t) + + dir := t.TempDir() + dumpPath := filepath.Join(dir, "create.sql") + sqlitePath := filepath.Join(dir, "create.sqlite") + + var b strings.Builder + b.WriteString("PRAGMA defer_foreign_keys=TRUE;\n") + b.WriteString("CREATE TABLE multi (\n") + for i := 0; i < 40; i++ { + fmt.Fprintf(&b, " col_%d TEXT,\n", i) + } + b.WriteString(" id INTEGER PRIMARY KEY\n") + b.WriteString(");\n") + for i := 0; i < 10; i++ { + fmt.Fprintf(&b, "INSERT INTO multi (id) VALUES(%d);\n", i) + } + if err := os.WriteFile(dumpPath, []byte(b.String()), 0o600); err != nil { + t.Fatal(err) + } + + sqlite3, err := FindSQLite3() + if err != nil { + t.Fatal(err) + } + // Small chunks force splits that would bisect CREATE TABLE without boundary-aware flushing. 
+ if err := loadSQLiteDumpChunked(context.Background(), sqlite3, dumpPath, sqlitePath, 200); err != nil { + t.Fatalf("loadSQLiteDumpChunked: %v", err) + } + + counts, err := CountSQLiteRows(context.Background(), sqlitePath, []string{"multi"}) + if err != nil { + t.Fatal(err) + } + if counts["multi"] != 10 { + t.Fatalf("expected 10 rows, got %d", counts["multi"]) + } +} + +func TestSQLStatementBoundary(t *testing.T) { + tests := []struct { + line string + want bool + }{ + {"CREATE TABLE t (id INTEGER);\n", true}, + {" );\n", true}, + {" payload BLOB\n", false}, + {"INSERT INTO t VALUES('a;b');\n", true}, + {"INSERT INTO t VALUES('a;b\n", false}, + {"-- comment only\n", false}, + {"PRAGMA defer_foreign_keys=TRUE;\n", true}, + } + for _, tc := range tests { + if got := lineEndsSQLStatement([]byte(tc.line)); got != tc.want { + t.Fatalf("lineEndsSQLStatement(%q) = %v, want %v", tc.line, got, tc.want) + } + } +} + +func TestLoadSQLiteDumpChunked(t *testing.T) { + requireSQLite3(t) + + dir := t.TempDir() + dumpPath := filepath.Join(dir, "multi.sql") + sqlitePath := filepath.Join(dir, "multi.sqlite") + + var b strings.Builder + b.WriteString("PRAGMA defer_foreign_keys=TRUE;\n") + b.WriteString("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT);\n") + for i := 0; i < 200; i++ { + fmt.Fprintf(&b, "INSERT INTO t (id, v) VALUES(%d,'row');\n", i) + } + if err := os.WriteFile(dumpPath, []byte(b.String()), 0o600); err != nil { + t.Fatal(err) + } + + sqlite3, err := FindSQLite3() + if err != nil { + t.Fatal(err) + } + // Force many small chunks to exercise batching. 
+ if err := loadSQLiteDumpChunked(context.Background(), sqlite3, dumpPath, sqlitePath, 256); err != nil { + t.Fatalf("loadSQLiteDumpChunked: %v", err) + } + + counts, err := CountSQLiteRows(context.Background(), sqlitePath, []string{"t"}) + if err != nil { + t.Fatal(err) + } + if counts["t"] != 200 { + t.Fatalf("expected 200 rows, got %d", counts["t"]) + } +} diff --git a/internal/import/d1/state.go b/internal/import/d1/state.go new file mode 100644 index 000000000..eda7c0f3a --- /dev/null +++ b/internal/import/d1/state.go @@ -0,0 +1,187 @@ +package d1 + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/adrg/xdg" +) + +// StateStore manages local migration state. +type StateStore struct { + dir string +} + +// NewStateStore returns the default state store location. +func NewStateStore() (*StateStore, error) { + dir, err := xdg.ConfigFile("planetscale/import-d1") + if err != nil { + return nil, fmt.Errorf("state dir: %w", err) + } + if os.Getenv("PSCALE_TEST_MODE") == "1" { + dir = filepath.Join(os.TempDir(), "pscale-import-d1-test") + } + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, err + } + return &StateStore{dir: dir}, nil +} + +func (s *StateStore) statePath(org, database, branch, migrationID string) string { + key := fmt.Sprintf("%s_%s_%s_%s.json", sanitize(org), sanitize(database), sanitize(branch), sanitize(migrationID)) + return filepath.Join(s.dir, key) +} + +func sanitize(s string) string { + out := make([]byte, 0, len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' { + out = append(out, c) + } else { + out = append(out, '_') + } + } + return string(out) +} + +// Save persists migration state. 
+func (s *StateStore) Save(state *MigrationState) error { + if state.CreatedAt.IsZero() { + state.CreatedAt = time.Now().UTC() + } + state.UpdatedAt = time.Now().UTC() + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return err + } + path := s.statePath(state.Org, state.Database, state.Branch, state.MigrationID) + return os.WriteFile(path, data, 0o600) +} + +// Load retrieves migration state by ID. +func (s *StateStore) Load(org, database, branch, migrationID string) (*MigrationState, error) { + path := s.statePath(org, database, branch, migrationID) + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, newMigrationError(ErrCodeNotFound, "migration state not found", "Run `import d1 start --dry-run` or `import d1 start` to create migration state") + } + return nil, err + } + var state MigrationState + if err := json.Unmarshal(data, &state); err != nil { + return nil, err + } + return &state, nil +} + +// Delete removes migration state. +func (s *StateStore) Delete(org, database, branch, migrationID string) error { + path := s.statePath(org, database, branch, migrationID) + err := os.Remove(path) + if os.IsNotExist(err) { + return nil + } + return err +} + +// SaveState is a package-level helper using the default store. +func SaveState(state *MigrationState) error { + store, err := NewStateStore() + if err != nil { + return err + } + return store.Save(state) +} + +// LoadState loads state using the default store. +func LoadState(org, database, branch, migrationID string) (*MigrationState, error) { + store, err := NewStateStore() + if err != nil { + return nil, err + } + return store.Load(org, database, branch, migrationID) +} + +// SetMigrationPhase updates the phase on existing migration state. 
+func SetMigrationPhase(org, database, branch, migrationID, phase string) error { + return updateMigrationState(org, database, branch, migrationID, func(state *MigrationState) { + state.Phase = phase + }) +} + +func updateMigrationState(org, database, branch, migrationID string, update func(*MigrationState)) error { + state, err := LoadState(org, database, branch, migrationID) + if err != nil { + return err + } + update(state) + return SaveState(state) +} + +func saveImportMigrationState(opts ImportOptions, phase, sqlitePath string) error { + state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) + if err != nil { + if me, ok := migrationErr(err); ok && me.Info.Code == ErrCodeNotFound { + state = &MigrationState{ + MigrationID: opts.MigrationID, + Org: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + } + } else { + return err + } + } + state.Phase = phase + if opts.InputPath != "" { + state.InputPath = opts.InputPath + } + if opts.Method != "" { + state.Method = opts.Method + } + if opts.DBName != "" { + state.DBName = opts.DBName + } + if sqlitePath != "" { + state.SQLitePath = sqlitePath + } + return SaveState(state) +} + +// Complete marks a migration as finished in local state. 
// Teardown finishes a migration by calling Complete with the same arguments
// and returning its result.
//
// Deprecated: use Complete.
func Teardown(org, database, branch, migrationID string, api NotifyAPIConfig) error {
	return Complete(org, database, branch, migrationID, api)
}
err) + } + + if err := Complete(org, database, branch, migrationID, NotifyAPIConfig{}); err == nil { + t.Fatal("expected error completing unverified migration") + } +} + +func TestMigrationPhaseTransitions(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + org, database, branch := "acme", "mydb", "main" + migrationID := "testphase123" + + plan := &PlanResult{ + MigrationID: migrationID, + Org: org, + Database: database, + Branch: branch, + InputPath: testFixture(t), + } + if err := SavePlan(plan); err != nil { + t.Fatalf("SavePlan: %v", err) + } + + state, err := LoadState(org, database, branch, migrationID) + if err != nil { + t.Fatalf("LoadState: %v", err) + } + if state.Phase != PhasePlanned { + t.Fatalf("phase = %q, want %q", state.Phase, PhasePlanned) + } + + opts := ImportOptions{ + Org: org, + Database: database, + Branch: branch, + MigrationID: migrationID, + InputPath: plan.InputPath, + Method: MethodPgloader, + } + if err := saveImportMigrationState(opts, PhaseImporting, ""); err != nil { + t.Fatalf("saveImportMigrationState importing: %v", err) + } + state, err = LoadState(org, database, branch, migrationID) + if err != nil { + t.Fatalf("LoadState importing: %v", err) + } + if state.Phase != PhaseImporting { + t.Fatalf("phase = %q, want %q", state.Phase, PhaseImporting) + } + + if err := SetMigrationPhase(org, database, branch, migrationID, PhaseImported); err != nil { + t.Fatalf("SetMigrationPhase imported: %v", err) + } + if err := SetMigrationPhase(org, database, branch, migrationID, PhaseVerified); err != nil { + t.Fatalf("SetMigrationPhase verified: %v", err) + } + if err := Complete(org, database, branch, migrationID, NotifyAPIConfig{}); err != nil { + t.Fatalf("Complete: %v", err) + } + + state, err = LoadState(org, database, branch, migrationID) + if err != nil { + t.Fatalf("LoadState complete: %v", err) + } + if state.Phase != PhaseComplete { + t.Fatalf("phase = %q, want %q", state.Phase, PhaseComplete) + } +} + +func 
TestComplete_SucceedsWhenNotifyAPIFails(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + client, err := ps.NewClient( + ps.WithBaseURL(srv.URL), + ps.WithAccessToken("token"), + ) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + + org, database, branch := "acme", "mydb", "main" + migrationID := "notifyfail789" + if err := SavePlan(&PlanResult{ + MigrationID: migrationID, + Org: org, + Database: database, + Branch: branch, + InputPath: testFixture(t), + }); err != nil { + t.Fatalf("SavePlan: %v", err) + } + if err := SetMigrationPhase(org, database, branch, migrationID, PhaseVerified); err != nil { + t.Fatalf("SetMigrationPhase verified: %v", err) + } + + if err := Complete(org, database, branch, migrationID, NotifyAPIConfig{Client: client}); err != nil { + t.Fatalf("Complete should succeed when notify API fails: %v", err) + } + + state, err := LoadState(org, database, branch, migrationID) + if err != nil { + t.Fatalf("LoadState: %v", err) + } + if state.Phase != PhaseComplete { + t.Fatalf("phase = %q, want %q", state.Phase, PhaseComplete) + } +} + +func TestComplete_SendsCompletePayload(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + var body map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + } + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal body: %v", err) + } + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + client, err := ps.NewClient( + ps.WithBaseURL(srv.URL), + ps.WithAccessToken("token"), + ) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + + org, database, branch := "acme", "mydb", "main" + migrationID := "completepayload123" + inputPath := testFixture(t) + if err := 
SavePlan(&PlanResult{ + MigrationID: migrationID, + Org: org, + Database: database, + Branch: branch, + InputPath: inputPath, + }); err != nil { + t.Fatalf("SavePlan: %v", err) + } + + state, err := LoadState(org, database, branch, migrationID) + if err != nil { + t.Fatalf("LoadState: %v", err) + } + state.Method = MethodPgloader + state.LoadedTables = []string{"organizations", "users"} + state.CreatedAt = time.Now().UTC().Add(-5 * time.Minute) + if err := SaveState(state); err != nil { + t.Fatalf("SaveState: %v", err) + } + if err := SetMigrationPhase(org, database, branch, migrationID, PhaseVerified); err != nil { + t.Fatalf("SetMigrationPhase verified: %v", err) + } + + if err := Complete(org, database, branch, migrationID, NotifyAPIConfig{Client: client}); err != nil { + t.Fatalf("Complete: %v", err) + } + + if body["event"] != NotifyEventComplete { + t.Fatalf("event = %v, want %s", body["event"], NotifyEventComplete) + } + if body["method"] != MethodPgloader { + t.Fatalf("method = %v, want %s", body["method"], MethodPgloader) + } + if body["table_count"].(float64) != 2 { + t.Fatalf("table_count = %v, want 2", body["table_count"]) + } + if body["export_bytes"].(float64) <= 0 { + t.Fatalf("export_bytes = %v, want > 0", body["export_bytes"]) + } + if body["duration_ms"].(float64) <= 0 { + t.Fatalf("duration_ms = %v, want > 0", body["duration_ms"]) + } + msg, _ := body["message"].(string) + if msg == "" || !strings.Contains(msg, "re-baseline ORM migrations") { + t.Fatalf("message = %v, want ORM re-baseline guidance", body["message"]) + } +} + +func TestSaveImportMigrationStateFailed(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + org, database, branch := "acme", "mydb", "main" + migrationID := "testfailed456" + if err := SavePlan(&PlanResult{ + MigrationID: migrationID, + Org: org, + Database: database, + Branch: branch, + InputPath: testFixture(t), + }); err != nil { + t.Fatalf("SavePlan: %v", err) + } + + opts := ImportOptions{ + Org: org, + Database: 
database, + Branch: branch, + MigrationID: migrationID, + InputPath: testFixture(t), + Method: MethodPgloader, + } + if err := saveImportMigrationState(opts, PhaseFailed, "/tmp/test.sqlite"); err != nil { + t.Fatalf("saveImportMigrationState failed: %v", err) + } + + state, err := LoadState(org, database, branch, migrationID) + if err != nil { + t.Fatalf("LoadState: %v", err) + } + if state.Phase != PhaseFailed { + t.Fatalf("phase = %q, want %q", state.Phase, PhaseFailed) + } + if state.SQLitePath != "/tmp/test.sqlite" { + t.Fatalf("sqlite path = %q", state.SQLitePath) + } +} diff --git a/internal/import/d1/testdata/sample_d1_export.sql b/internal/import/d1/testdata/sample_d1_export.sql new file mode 100644 index 000000000..d6fe3ade6 --- /dev/null +++ b/internal/import/d1/testdata/sample_d1_export.sql @@ -0,0 +1,71 @@ +-- Sample D1 export fixture for import d1 tests +PRAGMA foreign_keys=OFF; + +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT NOT NULL UNIQUE, + active INTEGER DEFAULT 1, + created_at TEXT NOT NULL +); + +CREATE TABLE posts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + title TEXT NOT NULL, + body TEXT, + published INTEGER DEFAULT 0, + metadata TEXT, + FOREIGN KEY (user_id) REFERENCES users(id) +); + +CREATE TABLE external_entities ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + created_at TEXT NOT NULL +); + +CREATE TABLE entity_links ( + entity_id TEXT NOT NULL, + post_id INTEGER NOT NULL, + linked_at TEXT NOT NULL, + PRIMARY KEY (entity_id, post_id), + FOREIGN KEY (entity_id) REFERENCES external_entities(id), + FOREIGN KEY (post_id) REFERENCES posts(id) +); + +CREATE TABLE __drizzle_migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + hash TEXT NOT NULL, + created_at INTEGER +); + +CREATE TABLE _prisma_migrations ( + id TEXT PRIMARY KEY, + checksum TEXT NOT NULL, + finished_at TEXT, + migration_name TEXT NOT NULL, + logs TEXT, + rolled_back_at TEXT, + started_at TEXT NOT NULL, + 
applied_steps_count INTEGER NOT NULL +); + +INSERT INTO users (id, email, active, created_at) VALUES + (1, 'alice@example.com', 1, '2024-01-01T00:00:00Z'), + (2, 'bob@example.com', 0, '2024-01-02T00:00:00Z'); + +INSERT INTO posts (id, user_id, title, body, published, metadata) VALUES + (1, 1, 'Hello', 'World', 1, '{"tags":["intro"]}'), + (2, 1, 'Draft', 'Work in progress', 0, NULL); + +INSERT INTO external_entities (id, name, created_at) VALUES + ('550e8400-e29b-41d4-a716-446655440000', 'Webhook A', '2024-01-01T00:00:00Z'); + +INSERT INTO entity_links (entity_id, post_id, linked_at) VALUES + ('550e8400-e29b-41d4-a716-446655440000', 1, '2024-01-02T00:00:00Z'); + +INSERT INTO __drizzle_migrations (id, hash, created_at) VALUES + (1, 'abc123', 1700000000); + +CREATE INDEX idx_users_email ON users(email); +CREATE UNIQUE INDEX idx_entity_links_post ON entity_links(post_id); diff --git a/internal/import/d1/types.go b/internal/import/d1/types.go new file mode 100644 index 000000000..d2ebd0f86 --- /dev/null +++ b/internal/import/d1/types.go @@ -0,0 +1,189 @@ +package d1 + +import "time" + +// Severity levels for lint/plan issues. +const ( + SeverityError = "error" + SeverityWarning = "warning" + SeverityInfo = "info" +) + +// Issue describes a migration concern with agent-friendly remediation. +type Issue struct { + Code string `json:"code"` + Severity string `json:"severity"` + Table string `json:"table,omitempty"` + Column string `json:"column,omitempty"` + Message string `json:"message,omitempty"` + Remediation string `json:"remediation"` +} + +// NextStep guides agents to the next tool or command. +type NextStep struct { + Tool string `json:"tool,omitempty"` + Command string `json:"command,omitempty"` + Reason string `json:"reason"` +} + +// Response is the common JSON envelope for import d1 commands. 
type Response struct {
	// Status is always emitted (no omitempty on the tag).
	Status string `json:"status"`
	// Command, Phase and MigrationID are dropped from the JSON when empty.
	Command     string `json:"command,omitempty"`
	Phase       string `json:"phase,omitempty"`
	MigrationID string `json:"migration_id,omitempty"`
	// Issues and NextSteps are dropped from the JSON when the slice is empty.
	Issues    []Issue    `json:"issues,omitempty"`
	NextSteps []NextStep `json:"next_steps,omitempty"`
	Reminder  string     `json:"reminder,omitempty"`
	// Data is untyped; presumably the command-specific payload
	// (e.g. a LintResult or PlanResult) — confirm against the callers.
	Data any `json:"data,omitempty"`
	// Error is a pointer so that it is omitted entirely when nil.
	Error *ErrorInfo `json:"error,omitempty"`
}

// ErrorInfo is a structured CLI/MCP error.
type ErrorInfo struct {
	// Code and Message are always emitted; Remediation only when non-empty.
	Code        string `json:"code"`
	Message     string `json:"message"`
	Remediation string `json:"remediation,omitempty"`
}

// DoctorResult lists prerequisite checks.
type DoctorResult struct {
	// Checks holds one entry per prerequisite check.
	Checks []DoctorCheck `json:"checks"`
	// Ready is presumably true only when every check passed — confirm
	// against the code that builds this value.
	Ready bool `json:"ready"`
}

// DoctorCheck is a single prerequisite check.
type DoctorCheck struct {
	// Name and Status are always emitted; the remaining fields are optional
	// and dropped from the JSON when empty.
	Name        string `json:"name"`
	Status      string `json:"status"`
	Version     string `json:"version,omitempty"`
	Message     string `json:"message,omitempty"`
	Remediation string `json:"remediation,omitempty"`
}

// LintResult summarizes lint output.
type LintResult struct {
	InputPath  string `json:"input_path"`
	TableCount int    `json:"table_count"`
	// ErrorCount and WarningCount presumably tally Issues by severity
	// (SeverityError / SeverityWarning) — confirm against the linter.
	ErrorCount   int `json:"error_count"`
	WarningCount int `json:"warning_count"`
	// Issues and Tables have no omitempty: a nil slice encodes as JSON null.
	Issues []Issue  `json:"issues"`
	Tables []string `json:"tables"`
}

// PlanResult is the migration plan JSON.
type PlanResult struct {
	MigrationID string `json:"migration_id"`
	InputPath   string `json:"input_path"`
	// Org, Database and Branch identify the destination of the migration.
	Org               string `json:"org"`
	Database          string `json:"database"`
	Branch            string `json:"branch"`
	RecommendedMethod string `json:"recommended_method"`
	// EstimatedSizeBytes is dropped from the JSON when zero.
	EstimatedSizeBytes int64 `json:"estimated_size_bytes,omitempty"`
	// The slices below have no omitempty: a nil slice encodes as JSON null.
	Tables    []TablePlan `json:"tables"`
	CastRules []CastRule  `json:"cast_rules"`
	LoadOrder []string    `json:"load_order"`
	Issues    []Issue     `json:"issues"`
}

// TablePlan describes a table in the migration plan.
type TablePlan struct {
	Name string `json:"name"`
	// RowEstimate is dropped from the JSON when zero.
	RowEstimate int `json:"row_estimate,omitempty"`
	// HasFK is serialized under the longer key "has_foreign_keys".
	HasFK bool `json:"has_foreign_keys"`
}

// CastRule maps SQLite types to Postgres casts for pgloader.
type CastRule struct {
	SourceType string `json:"source_type"`
	TargetType string `json:"target_type"`
	// Using and Tables are optional and dropped from the JSON when empty.
	// NOTE(review): Tables is a single string, not a slice — presumably a
	// pre-joined list or pattern; confirm against the pgloader generator.
	Using  string `json:"using,omitempty"`
	Tables string `json:"tables,omitempty"`
}

// ImportResult describes an import run.
type ImportResult struct {
	MigrationID string `json:"migration_id"`
	Method      string `json:"method"`
	DryRun      bool   `json:"dry_run"`
	// TablesLoaded is dropped from the JSON when zero.
	TablesLoaded int `json:"tables_loaded,omitempty"`
	// Timings, Lint and Plan are pointers so they are omitted when nil.
	Timings *ImportTimings `json:"timings,omitempty"`
	Lint    *LintResult    `json:"lint,omitempty"`
	Plan    *PlanResult    `json:"plan,omitempty"`
	// CanProceed is always emitted, including when false.
	CanProceed bool `json:"can_proceed"`
}

// ImportTimings breaks down import wall-clock time by phase.
type ImportTimings struct {
	// All durations are presumably milliseconds, per the Ms suffix — confirm
	// against the code that fills them in. Only TotalMs is always emitted.
	TotalMs         int64 `json:"total_ms"`
	SQLiteStagingMs int64 `json:"sqlite_staging_ms,omitempty"`
	SchemaMs        int64 `json:"schema_ms,omitempty"`
	PgloaderMs      int64 `json:"pgloader_ms,omitempty"`
	IndexBuildMs    int64 `json:"index_build_ms,omitempty"`
	SequenceResetMs int64 `json:"sequence_reset_ms,omitempty"`
	// TableLoads is the optional per-table breakdown.
	TableLoads []TableLoadTiming `json:"table_loads,omitempty"`
}

// TableLoadTiming is per-table pgloader duration.
type TableLoadTiming struct {
	Table string `json:"table"`
	Ms    int64  `json:"ms"`
}

// VerifyOptions configures post-import verification.
type VerifyOptions struct {
	// Org, Database, Branch and MigrationID are the keys Verify passes to
	// LoadState and SetMigrationPhase. MigrationID may be empty, in which
	// case Verify skips the saved-state lookups and the phase update.
	Org         string
	Database    string
	Branch      string
	MigrationID string
	// InputPath is the dump that Verify parses with ParseDump. When it is
	// empty and MigrationID is set, resolveVerifySQLitePath falls back to the
	// input path recorded in the saved state.
	InputPath string
	// SQLitePath, when set, is used as the staging database as-is; Verify
	// then does not rebuild it from the dump.
	SQLitePath string
	// DestURI is required: Verify returns ErrCodeInvalidInput when it is empty.
	DestURI string
	DBName  string // destination PostgreSQL database name (default postgres)
	// NotifyAPI is handed to the notification helpers called by Verify.
	NotifyAPI NotifyAPIConfig
	// OnProgress is presumably the callback behind opts.reportProgress —
	// confirm against the reportProgress method, which is not in this file.
	OnProgress ImportProgressFunc
	// notifyBase is unexported; Verify sets it from notifyPayloadFromVerify,
	// so callers do not populate it.
	notifyBase importNotificationPayload
}

// VerifyResult compares source and destination after import.
+type VerifyResult struct { + MigrationID string `json:"migration_id"` + Matched bool `json:"matched"` + Tables []TableVerifyResult `json:"tables"` + Checks []VerifyCheckResult `json:"checks,omitempty"` +} + +// TableVerifyResult is per-table verification. +type TableVerifyResult struct { + Table string `json:"table"` + SourceRows int64 `json:"source_rows"` + DestRows int64 `json:"dest_rows"` + Match bool `json:"match"` +} + +// Migration phases persisted in local state. +const ( + PhasePlanned = "planned" + PhaseImporting = "importing" + PhaseImported = "imported" + PhaseVerified = "verified" + PhaseFailed = "failed" + PhaseComplete = "complete" +) + +// MigrationState is persisted local migration metadata. +type MigrationState struct { + MigrationID string `json:"migration_id"` + Org string `json:"org"` + Database string `json:"database"` + Branch string `json:"branch"` + InputPath string `json:"input_path"` + SQLitePath string `json:"sqlite_path,omitempty"` + DBName string `json:"db_name,omitempty"` + Method string `json:"method,omitempty"` + Phase string `json:"phase"` + SchemaApplied bool `json:"schema_applied,omitempty"` + LoadedTables []string `json:"loaded_tables,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/internal/import/d1/verify.go b/internal/import/d1/verify.go new file mode 100644 index 000000000..00a4b9175 --- /dev/null +++ b/internal/import/d1/verify.go @@ -0,0 +1,336 @@ +package d1 + +import ( + "context" + "fmt" + "time" + + "github.com/planetscale/cli/internal/postgres" + execabs "golang.org/x/sys/execabs" +) + +// Verify compares SQLite source data with PlanetScale Postgres after import. 
//
// Verify requires opts.DestURI and returns an ErrCodeInvalidInput error without
// it. It resolves the SQLite staging database, parses the dump, and then runs
// these checks against the destination: row counts, identity sequences,
// boolean columns, table fingerprints and sampled rows. Tables for which
// IsORMMetadataTable reports true are excluded from the checks.
//
// When any check fails, Verify returns the populated result together with an
// ErrCodeVerifyFailed error. Errors raised before the result is complete are
// returned with a nil result. On success the migration phase is set to
// PhaseVerified (only when opts.MigrationID is set) and a NotifyEventVerified
// event is sent if notifications are enabled and a client is configured.
func Verify(ctx context.Context, opts VerifyOptions) (result *VerifyResult, err error) {
	verifyStart := time.Now()
	verifyChecksPassed := false
	// The deferred closure reads the named results, so it sees whatever the
	// function finally returns. It reports a failure for every error returned
	// before the checks passed; errors after that point (state persistence)
	// do not trigger a failure notification.
	defer func() {
		if err == nil || verifyChecksPassed {
			return
		}
		payload := importNotificationPayload{
			DurationMs: time.Since(verifyStart).Milliseconds(),
		}
		notifyImportFailure(opts.NotifyAPI, opts.Org, opts.Database, opts.Branch, opts.MigrationID, payload, err, result)
	}()

	if opts.DestURI == "" {
		return nil, newMigrationError(
			ErrCodeInvalidInput,
			"destination database connection required for verify",
			"Pass database and branch as positional arguments so verify can compare against PlanetScale Postgres",
		)
	}

	// opts and err already exist in this scope, so := only declares
	// sqlitePath and assigns to the other two. opts may come back with
	// InputPath filled in from the saved migration state.
	opts, sqlitePath, err := resolveVerifySQLitePath(opts)
	if err != nil {
		return nil, err
	}

	// Only build the staging database when the caller did not supply one
	// explicitly via opts.SQLitePath.
	if opts.InputPath != "" && opts.SQLitePath == "" {
		if err := EnsureSQLiteFromDump(ctx, opts.InputPath, sqlitePath); err != nil {
			return nil, newMigrationError(
				ErrCodeVerifyFailed,
				fmt.Sprintf("build sqlite staging: %v", err),
				"Ensure the dump is valid and sqlite3 is installed; pass --sqlite for a custom staging path",
			)
		}
	}

	// Destination database name precedence: explicit option, then the name
	// recorded in the saved state, then "postgres". A LoadState failure here
	// is deliberately ignored and simply falls through to the default.
	dbName := opts.DBName
	if dbName == "" && opts.MigrationID != "" {
		if state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID); err == nil && state.DBName != "" {
			dbName = state.DBName
		}
	}
	if dbName == "" {
		dbName = "postgres"
	}
	opts.DBName = dbName

	opts.notifyBase = notifyPayloadFromVerify(opts)

	tables, err := ParseDump(opts.InputPath)
	if err != nil {
		return nil, err
	}

	coerceCtx, err := BuildTypeCoercionContext(opts.InputPath, tables)
	if err != nil {
		return nil, err
	}

	// tableNames and dataTables are parallel views of the same filtered set:
	// every parsed table except ORM metadata tables.
	tableNames := make([]string, 0, len(tables))
	dataTables := make([]TableSchema, 0, len(tables))
	for _, t := range tables {
		if IsORMMetadataTable(t.Name) {
			continue
		}
		tableNames = append(tableNames, t.Name)
		dataTables = append(dataTables, t)
	}

	NotifyImportEventSync(opts.NotifyAPI, opts.Org, opts.Database, opts.Branch, opts.MigrationID, NotifyEventVerifying, importNotificationPayload{})

	opts.reportProgress(ImportProgress{Stage: VerifyStageRowCounts, Total: len(tableNames)})
	sourceCounts, err := countSQLiteRowsWithProgress(ctx, opts, sqlitePath, tableNames)
	if err != nil {
		return nil, newMigrationError(
			ErrCodeVerifyFailed,
			fmt.Sprintf("count source rows: %v", err),
			"Ensure sqlite3 is installed and the staging database is readable; pass --sqlite if using a custom path",
		)
	}

	destCounts, err := countPostgresRowsWithProgress(ctx, opts, tableNames)
	if err != nil {
		return nil, err
	}
	// extraTables are tables recorded as loaded in the saved state that are
	// not among the source tables; their destination counts are merged in.
	destCounts, extraTables, err := mergeImportScopedDestRowCounts(ctx, opts, tableNames, destCounts)
	if err != nil {
		return nil, err
	}

	// Matched starts true and is only ever flipped to false by a failed check.
	// Checks is initialized non-nil so it is never JSON null.
	result = &VerifyResult{
		MigrationID: opts.MigrationID,
		Matched:     true,
		Checks:      []VerifyCheckResult{},
	}

	// Copy tableNames before appending so the original slice is not aliased.
	verifyTables := append(append([]string{}, tableNames...), extraTables...)
	var rowCountsOK bool
	result.Tables, rowCountsOK = verifyRowCounts(verifyTables, sourceCounts, destCounts)
	if !rowCountsOK {
		result.Matched = false
	}

	// A row-count mismatch does not stop verification: the remaining checks
	// still run so the result lists every failing check.
	db, err := OpenPostgres(opts.DestURI)
	if err != nil {
		return nil, err
	}
	defer db.Close()

	opts.reportProgress(ImportProgress{Stage: VerifyStageSequences})
	seqChecks, ok := verifyIdentitySequences(ctx, db, dataTables)
	result.Checks = append(result.Checks, seqChecks...)
	if !ok {
		result.Matched = false
	}

	opts.reportProgress(ImportProgress{Stage: VerifyStageBoolean})
	boolChecks, ok, err := verifyBooleanColumns(ctx, db, sqlitePath, dataTables, coerceCtx)
	if err != nil {
		return nil, err
	}
	result.Checks = append(result.Checks, boolChecks...)
	if !ok {
		result.Matched = false
	}

	opts.reportProgress(ImportProgress{Stage: VerifyStageFingerprints})
	fpChecks, ok, err := verifyTableFingerprints(ctx, db, sqlitePath, dataTables, coerceCtx)
	if err != nil {
		return nil, err
	}
	result.Checks = append(result.Checks, fpChecks...)
	if !ok {
		result.Matched = false
	}

	opts.reportProgress(ImportProgress{Stage: VerifyStageSampleRows})
	// Sample at most 8 tables, 3 rows per table.
	sampleChecks, ok, err := verifySampleRows(ctx, db, sqlitePath, dataTables, coerceCtx, 8, 3)
	if err != nil {
		return nil, err
	}
	result.Checks = append(result.Checks, sampleChecks...)
	if !ok {
		result.Matched = false
	}

	// Unlike the error paths above, a mismatch returns the populated result
	// alongside the error so callers can inspect the failing checks.
	if !result.Matched {
		return result, newMigrationError(
			ErrCodeVerifyFailed,
			"import verification failed (row counts, sequences, coercion, or content checks)",
			"Re-run import or inspect failing checks in verify JSON output",
		)
	}

	// From here on, errors no longer trigger the deferred failure notification.
	verifyChecksPassed = true

	if opts.MigrationID != "" {
		if err := SetMigrationPhase(opts.Org, opts.Database, opts.Branch, opts.MigrationID, PhaseVerified); err != nil {
			return result, errStatePersist("verify", err)
		}
	}

	if !opts.NotifyAPI.Disabled && opts.NotifyAPI.Client != nil {
		// Copy to a local so the payload can hold a pointer to the value.
		matched := result.Matched
		NotifyImportEventSync(opts.NotifyAPI, opts.Org, opts.Database, opts.Branch, opts.MigrationID, NotifyEventVerified, importNotificationPayload{
			Matched:    &matched,
			DurationMs: time.Since(verifyStart).Milliseconds(),
		})
	}

	return result, nil
}

// CountSQLiteRows counts rows using sqlite3 CLI.
+func CountSQLiteRows(ctx context.Context, sqlitePath string, tables []string) (map[string]int64, error) { + return countSQLiteRowsWithProgress(ctx, VerifyOptions{}, sqlitePath, tables) +} + +func countSQLiteRowsWithProgress(ctx context.Context, opts VerifyOptions, sqlitePath string, tables []string) (map[string]int64, error) { + sqlite3, err := FindSQLite3() + if err != nil { + return nil, err + } + + counts := make(map[string]int64, len(tables)) + for i, table := range tables { + opts.reportProgress(ImportProgress{ + Stage: VerifyStageRowCounts, + Current: i + 1, + Total: len(tables), + Detail: table + " (sqlite)", + }) + query := fmt.Sprintf("SELECT COUNT(*) FROM %q;", table) + cmd := execabs.CommandContext(ctx, sqlite3, sqlitePath, query) + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("sqlite count %s: %w", table, err) + } + var count int64 + if _, err := fmt.Sscanf(string(out), "%d", &count); err != nil { + return nil, err + } + counts[table] = count + } + return counts, nil +} + +// CountPostgresRows counts rows in public schema tables. 
+func CountPostgresRows(ctx context.Context, destURI string, tables []string) (map[string]int64, error) { + return countPostgresRowsWithProgress(ctx, VerifyOptions{DestURI: destURI}, tables) +} + +func countPostgresRowsWithProgress(ctx context.Context, opts VerifyOptions, tables []string) (map[string]int64, error) { + db, err := OpenPostgres(opts.DestURI) + if err != nil { + return nil, err + } + defer db.Close() + + counts := make(map[string]int64, len(tables)) + for i, table := range tables { + opts.reportProgress(ImportProgress{ + Stage: VerifyStageRowCounts, + Current: i + 1, + Total: len(tables), + Detail: table + " (postgres)", + }) + var count int64 + query := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, postgres.QuoteIdentifier(table)) + if err := db.QueryRowContext(ctx, query).Scan(&count); err != nil { + return nil, fmt.Errorf("count %s: %w", table, err) + } + counts[table] = count + } + return counts, nil +} + +func mergeImportScopedDestRowCounts(ctx context.Context, opts VerifyOptions, sourceTables []string, destCounts map[string]int64) (map[string]int64, []string, error) { + if opts.MigrationID == "" { + return destCounts, nil, nil + } + state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) + if err != nil { + return destCounts, nil, nil + } + if len(state.LoadedTables) == 0 { + return destCounts, nil, nil + } + + sourceSet := make(map[string]struct{}, len(sourceTables)) + for _, name := range sourceTables { + sourceSet[name] = struct{}{} + } + + var extra []string + for _, name := range state.LoadedTables { + if _, ok := sourceSet[name]; ok { + continue + } + extra = append(extra, name) + } + if len(extra) == 0 { + return destCounts, nil, nil + } + + db, err := OpenPostgres(opts.DestURI) + if err != nil { + return nil, nil, err + } + defer db.Close() + + if destCounts == nil { + destCounts = make(map[string]int64) + } + for _, name := range extra { + var count int64 + query := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, 
postgres.QuoteIdentifier(name)) + if err := db.QueryRowContext(ctx, query).Scan(&count); err != nil { + return nil, nil, fmt.Errorf("count import-scoped table %s: %w", name, err) + } + destCounts[name] = count + } + return destCounts, extra, nil +} + +func resolveVerifySQLitePath(opts VerifyOptions) (VerifyOptions, string, error) { + if opts.SQLitePath != "" { + return opts, opts.SQLitePath, nil + } + + if opts.MigrationID != "" { + state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) + if err != nil { + return opts, "", err + } + if err := validateInputPathAgainstState(opts.InputPath, state.InputPath); err != nil { + return opts, "", err + } + if opts.InputPath == "" { + opts.InputPath = state.InputPath + } + if state.SQLitePath != "" { + return opts, state.SQLitePath, nil + } + if opts.InputPath == "" { + return opts, "", newMigrationError( + ErrCodeMissingInput, + "input dump path required for verify", + "Pass --input or run verify with a migration-id from a prior import", + ) + } + return opts, DefaultSQLitePath(opts.InputPath), nil + } + + if opts.InputPath == "" { + return opts, "", newMigrationError( + ErrCodeMissingInput, + "input dump path required for verify", + "Pass --input or run verify with a migration-id from a prior import", + ) + } + + return opts, DefaultSQLitePath(opts.InputPath), nil +} diff --git a/internal/import/d1/verify_checks.go b/internal/import/d1/verify_checks.go new file mode 100644 index 000000000..8d6aa323d --- /dev/null +++ b/internal/import/d1/verify_checks.go @@ -0,0 +1,725 @@ +package d1 + +import ( + "context" + "database/sql" + "encoding/hex" + "encoding/json" + "fmt" + "strings" + "unicode/utf8" + + "github.com/planetscale/cli/internal/postgres" + execabs "golang.org/x/sys/execabs" +) + +// VerifyCheckResult is a single post-import verification check. 
+type VerifyCheckResult struct { + Name string `json:"name"` + Table string `json:"table,omitempty"` + Column string `json:"column,omitempty"` + Matched bool `json:"matched"` + Message string `json:"message,omitempty"` + Source string `json:"source,omitempty"` + Dest string `json:"dest,omitempty"` +} + +const verifySignatureFieldMaxLen = 64 + +// summarizeRowSignatureForOutput shortens row signatures for CLI/JSON output. +// Full signatures are still used internally for comparison. +func summarizeRowSignatureForOutput(sig string, table TableSchema) string { + parts := strings.Split(sig, "|") + if len(parts) != len(table.Columns) { + return truncateSignatureValue(sig, verifySignatureFieldMaxLen*4, false) + } + for i, part := range parts { + parts[i] = summarizeSignatureField(part, table.Columns[i]) + } + return strings.Join(parts, "|") +} + +func summarizeSignatureField(value string, col ColumnSchema) string { + if value == "" { + return value + } + if isBlobColumn(col) { + byteLen := len(value) / 2 + if utf8.ValidString(value) && len(value)%2 == 0 { + if _, err := hex.DecodeString(value); err == nil && len(value) > verifySignatureFieldMaxLen { + return truncateSignatureValue(value, verifySignatureFieldMaxLen, true) + + fmt.Sprintf(" (%d bytes)", byteLen) + } + } + } + return truncateSignatureValue(value, verifySignatureFieldMaxLen, true) +} + +func truncateSignatureValue(value string, maxLen int, addEllipsis bool) string { + if len(value) <= maxLen { + return value + } + if addEllipsis { + return value[:maxLen] + "..." 
+ } + return value[:maxLen] +} + +type tableFingerprint struct { + RowCount int64 + IDSum string +} + +type booleanDistribution struct { + TrueCount int64 + FalseCount int64 + NullCount int64 +} + +func verifyRowCounts(tableNames []string, sourceCounts, destCounts map[string]int64) ([]TableVerifyResult, bool) { + results := make([]TableVerifyResult, 0, len(tableNames)) + matched := true + for _, name := range tableNames { + ok := sourceCounts[name] == destCounts[name] + if !ok { + matched = false + } + results = append(results, TableVerifyResult{ + Table: name, + SourceRows: sourceCounts[name], + DestRows: destCounts[name], + Match: ok, + }) + } + return results, matched +} + +func verifyIdentitySequences(ctx context.Context, db *sql.DB, tables []TableSchema) ([]VerifyCheckResult, bool) { + var checks []VerifyCheckResult + matched := true + + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + for _, col := range table.Columns { + if !col.AutoIncrement { + continue + } + check, ok, err := verifyTableSequence(ctx, db, table.Name, col.Name) + if err != nil { + checks = append(checks, VerifyCheckResult{ + Name: "sequences", + Table: table.Name, + Column: col.Name, + Matched: false, + Message: err.Error(), + }) + matched = false + continue + } + checks = append(checks, check) + if !ok { + matched = false + } + } + } + return checks, matched +} + +func verifyTableSequence(ctx context.Context, db *sql.DB, table, column string) (VerifyCheckResult, bool, error) { + check := VerifyCheckResult{ + Name: "sequences", + Table: table, + Column: column, + } + + var maxID sql.NullInt64 + maxQuery := fmt.Sprintf(`SELECT MAX(%s) FROM %s`, postgres.QuoteIdentifier(column), postgres.QuoteIdentifier(table)) + if err := db.QueryRowContext(ctx, maxQuery).Scan(&maxID); err != nil { + return check, false, fmt.Errorf("max %s.%s: %w", table, column, err) + } + if !maxID.Valid { + check.Matched = true + check.Message = "empty table" + return check, true, nil 
+ } + + var seqName sql.NullString + if err := db.QueryRowContext(ctx, + `SELECT pg_get_serial_sequence($1, $2)`, + "public."+table, + column, + ).Scan(&seqName); err != nil { + return check, false, fmt.Errorf("sequence lookup %s.%s: %w", table, column, err) + } + if !seqName.Valid || seqName.String == "" { + check.Matched = true + check.Message = "no sequence attached (non-identity column)" + return check, true, nil + } + + var lastValue int64 + var isCalled bool + seqQuery := fmt.Sprintf(`SELECT last_value, is_called FROM %s`, seqName.String) + if err := db.QueryRowContext(ctx, seqQuery).Scan(&lastValue, &isCalled); err != nil { + return check, false, fmt.Errorf("read sequence %s: %w", seqName.String, err) + } + + nextValue := lastValue + if isCalled { + nextValue = lastValue + 1 + } + ok := maxID.Int64 < nextValue + check.Matched = ok + check.Source = fmt.Sprintf("max=%d", maxID.Int64) + check.Dest = fmt.Sprintf("next=%d (last_value=%d is_called=%t)", nextValue, lastValue, isCalled) + if !ok { + check.Message = "sequence next value would collide with existing rows" + } else { + check.Message = "sequence ready for new inserts" + } + return check, ok, nil +} + +func verifyBooleanColumns(ctx context.Context, db *sql.DB, sqlitePath string, tables []TableSchema, coerceCtx *TypeCoercionContext) ([]VerifyCheckResult, bool, error) { + var checks []VerifyCheckResult + matched := true + + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + for _, col := range table.Columns { + if !isBooleanLikeColumn(col, table, coerceCtx) { + continue + } + src, err := sqliteBooleanDistribution(ctx, sqlitePath, table.Name, col.Name) + if err != nil { + return checks, false, err + } + dest, err := postgresBooleanDistribution(ctx, db, table.Name, col.Name) + if err != nil { + return checks, false, err + } + ok := src.TrueCount == dest.TrueCount && src.FalseCount == dest.FalseCount && src.NullCount == dest.NullCount + check := VerifyCheckResult{ + Name: 
"boolean_columns", + Table: table.Name, + Column: col.Name, + Matched: ok, + Source: fmt.Sprintf("true=%d false=%d null=%d", src.TrueCount, src.FalseCount, src.NullCount), + Dest: fmt.Sprintf("true=%d false=%d null=%d", dest.TrueCount, dest.FalseCount, dest.NullCount), + } + if !ok { + check.Message = "boolean value distribution mismatch after import" + matched = false + } else { + check.Message = "boolean coercion matches source 0/1 distribution" + } + checks = append(checks, check) + } + } + return checks, matched, nil +} + +func sqliteBooleanDistribution(ctx context.Context, sqlitePath, table, column string) (booleanDistribution, error) { + query := fmt.Sprintf( + `SELECT SUM(CASE WHEN %q = 1 THEN 1 ELSE 0 END), SUM(CASE WHEN %q = 0 THEN 1 ELSE 0 END), SUM(CASE WHEN %q IS NULL THEN 1 ELSE 0 END) FROM %q;`, + column, column, column, table, + ) + return querySQLiteDistribution(ctx, sqlitePath, query) +} + +func postgresBooleanDistribution(ctx context.Context, db *sql.DB, table, column string) (booleanDistribution, error) { + query := fmt.Sprintf( + `SELECT COUNT(*) FILTER (WHERE %s = TRUE), COUNT(*) FILTER (WHERE %s = FALSE), COUNT(*) FILTER (WHERE %s IS NULL) FROM %s`, + postgres.QuoteIdentifier(column), postgres.QuoteIdentifier(column), postgres.QuoteIdentifier(column), postgres.QuoteIdentifier(table), + ) + var dist booleanDistribution + if err := db.QueryRowContext(ctx, query).Scan(&dist.TrueCount, &dist.FalseCount, &dist.NullCount); err != nil { + return dist, err + } + return dist, nil +} + +func querySQLiteDistribution(ctx context.Context, sqlitePath, query string) (booleanDistribution, error) { + sqlite3, err := FindSQLite3() + if err != nil { + return booleanDistribution{}, err + } + out, err := runSQLiteQuery(ctx, sqlite3, sqlitePath, query) + if err != nil { + return booleanDistribution{}, err + } + parts := parseSQLiteCLIFields(out) + if len(parts) < 3 { + return booleanDistribution{}, fmt.Errorf("unexpected boolean count output: %q", string(out)) + } 
+ var dist booleanDistribution + for i, ptr := range []*int64{&dist.TrueCount, &dist.FalseCount, &dist.NullCount} { + if parts[i] == "" || parts[i] == "NULL" { + continue + } + if _, err := fmt.Sscanf(parts[i], "%d", ptr); err != nil { + return booleanDistribution{}, err + } + } + return dist, nil +} + +// parseSQLiteCLIFields splits sqlite3 CLI output. Multi-column results use '|'. +func parseSQLiteCLIFields(out []byte) []string { + s := strings.TrimSpace(string(out)) + if s == "" { + return nil + } + if strings.Contains(s, "|") { + parts := strings.Split(s, "|") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + return parts + } + return strings.Fields(s) +} + +func verifyTableFingerprints(ctx context.Context, db *sql.DB, sqlitePath string, tables []TableSchema, coerceCtx *TypeCoercionContext) ([]VerifyCheckResult, bool, error) { + var checks []VerifyCheckResult + matched := true + + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + pkCol := identityColumn(table) + src, err := tableFingerprintFromSQLite(ctx, sqlitePath, table, pkCol, tables, coerceCtx) + if err != nil { + return checks, false, err + } + dest, err := tableFingerprintFromPostgres(ctx, db, table, pkCol, tables, coerceCtx) + if err != nil { + return checks, false, err + } + ok := src.RowCount == dest.RowCount && src.IDSum == dest.IDSum + check := VerifyCheckResult{ + Name: "table_fingerprint", + Table: table.Name, + Matched: ok, + Source: fmt.Sprintf("rows=%d id_sum=%s", src.RowCount, src.IDSum), + Dest: fmt.Sprintf("rows=%d id_sum=%s", dest.RowCount, dest.IDSum), + } + if !ok { + check.Message = "aggregate fingerprint mismatch" + matched = false + } else if shouldFingerprintPKSum(table, pkCol, tables, coerceCtx) { + check.Message = "row count and integer PK sum match" + } else { + check.Message = "row count match" + } + checks = append(checks, check) + } + return checks, matched, nil +} + +func identityColumn(table TableSchema) string { + if 
cols := primaryKeyColumns(table); len(cols) == 1 { + return cols[0] + } + return "" +} + +func primaryKeyColumns(table TableSchema) []string { + var pks []string + for _, col := range table.Columns { + if col.PrimaryKey { + pks = append(pks, col.Name) + } + } + if len(pks) > 0 { + return pks + } + for _, col := range table.Columns { + if col.AutoIncrement { + return []string{col.Name} + } + } + return nil +} + +func shouldFingerprintPKSum(table TableSchema, pkCol string, all []TableSchema, coerceCtx *TypeCoercionContext) bool { + if pkCol == "" { + return false + } + col := columnByName(table, pkCol) + if col.Name == "" { + return false + } + if isUUIDColumn(col, table, all, coerceCtx) { + return false + } + upper := strings.ToUpper(col.Type) + return col.AutoIncrement || strings.Contains(upper, "INT") +} + +func tableFingerprintFromSQLite(ctx context.Context, sqlitePath string, table TableSchema, pkCol string, all []TableSchema, coerceCtx *TypeCoercionContext) (tableFingerprint, error) { + var query string + if pkCol != "" && shouldFingerprintPKSum(table, pkCol, all, coerceCtx) { + query = fmt.Sprintf(`SELECT COUNT(*), COALESCE(CAST(SUM(CAST(%q AS INTEGER)) AS TEXT), '0') FROM %q;`, pkCol, table.Name) + } else { + query = fmt.Sprintf(`SELECT COUNT(*), '0' FROM %q;`, table.Name) + } + sqlite3, err := FindSQLite3() + if err != nil { + return tableFingerprint{}, err + } + out, err := runSQLiteQuery(ctx, sqlite3, sqlitePath, query) + if err != nil { + return tableFingerprint{}, fmt.Errorf("sqlite fingerprint %s: %w", table.Name, err) + } + var fp tableFingerprint + fields := parseSQLiteCLIFields(out) + if len(fields) < 2 { + return tableFingerprint{}, fmt.Errorf("sqlite fingerprint %s: unexpected output %q", table.Name, string(out)) + } + if _, err := fmt.Sscanf(fields[0], "%d", &fp.RowCount); err != nil { + return tableFingerprint{}, fmt.Errorf("sqlite fingerprint %s row count: %w", table.Name, err) + } + fp.IDSum = fields[1] + return fp, nil +} + +func 
tableFingerprintFromPostgres(ctx context.Context, db *sql.DB, table TableSchema, pkCol string, all []TableSchema, coerceCtx *TypeCoercionContext) (tableFingerprint, error) { + var fp tableFingerprint + var query string + if pkCol != "" && shouldFingerprintPKSum(table, pkCol, all, coerceCtx) { + query = fmt.Sprintf(`SELECT COUNT(*), COALESCE(SUM(%s::numeric)::text, '0') FROM %s`, postgres.QuoteIdentifier(pkCol), postgres.QuoteIdentifier(table.Name)) + } else { + query = fmt.Sprintf(`SELECT COUNT(*), '0' FROM %s`, postgres.QuoteIdentifier(table.Name)) + } + if err := db.QueryRowContext(ctx, query).Scan(&fp.RowCount, &fp.IDSum); err != nil { + return fp, fmt.Errorf("postgres fingerprint %s: %w", table.Name, err) + } + return fp, nil +} + +func verifySampleRows(ctx context.Context, db *sql.DB, sqlitePath string, tables []TableSchema, coerceCtx *TypeCoercionContext, maxTables, samplesPerTable int) ([]VerifyCheckResult, bool, error) { + var checks []VerifyCheckResult + matched := true + checked := 0 + + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + if checked >= maxTables { + break + } + pkCols := primaryKeyColumns(table) + if len(pkCols) != 1 { + checks = append(checks, VerifyCheckResult{ + Name: "sample_rows", + Table: table.Name, + Message: "skipped (requires single-column primary key for row sampling)", + }) + continue + } + pkCol := pkCols[0] + if pkCol == "" { + continue + } + ids, err := samplePrimaryKeys(ctx, sqlitePath, table.Name, pkCol, samplesPerTable) + if err != nil { + return checks, false, err + } + if len(ids) == 0 { + continue + } + checked++ + + for _, id := range ids { + src, err := sqliteRowSignature(ctx, sqlitePath, table, pkCol, id, coerceCtx) + if err != nil { + return checks, false, err + } + dest, err := postgresRowSignature(ctx, db, table, pkCol, id, tables, coerceCtx) + if err != nil { + return checks, false, err + } + ok := rowSignaturesMatch(src, dest, table, tables, coerceCtx) + check := 
VerifyCheckResult{ + Name: "sample_rows", + Table: table.Name, + Column: pkCol, + Matched: ok, + } + if !ok { + check.Source = summarizeRowSignatureForOutput(src, table) + check.Dest = summarizeRowSignatureForOutput(dest, table) + check.Message = fmt.Sprintf("row signature mismatch for %s=%s", pkCol, id) + matched = false + } else { + check.Message = fmt.Sprintf("row signature match for %s=%s", pkCol, id) + } + checks = append(checks, check) + } + } + return checks, matched, nil +} + +func samplePrimaryKeys(ctx context.Context, sqlitePath, table, pkCol string, limit int) ([]string, error) { + sqlite3, err := FindSQLite3() + if err != nil { + return nil, err + } + query := fmt.Sprintf(`SELECT %q FROM %q ORDER BY %q LIMIT %d;`, pkCol, table, pkCol, limit) + out, err := runSQLiteQuery(ctx, sqlite3, sqlitePath, query) + if err != nil { + return nil, err + } + lines := strings.Split(strings.TrimSpace(string(out)), "\n") + var ids []string + for _, line := range lines { + line = strings.TrimSpace(line) + if line != "" { + ids = append(ids, line) + } + } + return ids, nil +} + +func sqliteSignatureColumnExpr(col ColumnSchema, table TableSchema, coerceCtx *TypeCoercionContext) string { + if isBooleanLikeColumn(col, table, coerceCtx) { + return fmt.Sprintf(`CASE WHEN %q IN (1, '1') THEN '1' WHEN %q IN (0, '0') THEN '0' ELSE '' END`, col.Name, col.Name) + } + if isJSONText(col) && coerceCtx != nil && samplesAllowJSON(table.Name, col.Name, coerceCtx) { + return fmt.Sprintf(`COALESCE(json(%q), CAST(%q AS TEXT), '')`, col.Name, col.Name) + } + if isBlobColumn(col) { + return fmt.Sprintf(`COALESCE(hex(%q), '')`, col.Name) + } + if isTimestampText(col) && coerceCtx != nil && samplesAllowTimestamp(table.Name, col.Name, coerceCtx) { + return fmt.Sprintf(`COALESCE(strftime('%%Y-%%m-%%dT%%H:%%M:%%SZ', %q), COALESCE(CAST(%q AS TEXT), ''))`, col.Name, col.Name) + } + return fmt.Sprintf(`COALESCE(CAST(%q AS TEXT), '')`, col.Name) +} + +func postgresSignatureColumnExpr(col ColumnSchema, 
table TableSchema, all []TableSchema, coerceCtx *TypeCoercionContext) string { + pgType := sqliteTypeToPostgres(col, table, all, coerceCtx) + switch pgType { + case "BOOLEAN": + name := postgres.QuoteIdentifier(col.Name) + return fmt.Sprintf(`CASE WHEN %s IS TRUE THEN '1' WHEN %s IS FALSE THEN '0' ELSE '' END`, name, name) + case "TIMESTAMPTZ": + name := postgres.QuoteIdentifier(col.Name) + return fmt.Sprintf(`COALESCE(to_char(%s AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"'), '')`, name) + case "JSONB": + name := postgres.QuoteIdentifier(col.Name) + return fmt.Sprintf(`COALESCE(%s::jsonb::text, '')`, name) + case "BYTEA": + name := postgres.QuoteIdentifier(col.Name) + return fmt.Sprintf(`COALESCE(encode(%s, 'hex'), '')`, name) + default: + return fmt.Sprintf(`COALESCE(%s::text, '')`, postgres.QuoteIdentifier(col.Name)) + } +} + +func rowSignaturesMatch(src, dest string, table TableSchema, all []TableSchema, coerceCtx *TypeCoercionContext) bool { + srcParts := strings.Split(src, "|") + destParts := strings.Split(dest, "|") + if len(srcParts) != len(destParts) || len(srcParts) != len(table.Columns) { + return src == dest + } + for i, col := range table.Columns { + pgType := sqliteTypeToPostgres(col, table, all, coerceCtx) + switch pgType { + case "JSONB": + if !jsonValuesEqual(srcParts[i], destParts[i]) { + return false + } + case "BYTEA": + if !byteaValuesEqual(srcParts[i], destParts[i]) { + return false + } + case "TIMESTAMPTZ": + if !timestampValuesEqual(srcParts[i], destParts[i]) { + return false + } + default: + if looksLikeJSON(srcParts[i]) && looksLikeJSON(destParts[i]) { + if !jsonValuesEqual(srcParts[i], destParts[i]) { + return false + } + continue + } + if srcParts[i] != destParts[i] { + return false + } + } + } + return true +} + +func jsonValuesEqual(a, b string) bool { + ca, errA := canonicalJSON(a) + cb, errB := canonicalJSON(b) + if errA == nil && errB == nil { + return ca == cb + } + if errA != nil && errB != nil { + if looksLikeJSON(a) || 
looksLikeJSON(b) { + return false + } + return a == b + } + return false +} + +func canonicalJSON(s string) (string, error) { + if s == "" { + return "", nil + } + var v any + if err := json.Unmarshal([]byte(s), &v); err != nil { + return "", err + } + b, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(b), nil +} + +func byteaValuesEqual(sqliteText, pgText string) bool { + if sqliteText == pgText { + return true + } + if strings.EqualFold(sqliteText, pgText) { + return true + } + if a, okA := decodeByteaSignature(sqliteText); okA { + if b, okB := decodeByteaSignature(pgText); okB { + return string(a) == string(b) + } + } + if strings.HasPrefix(pgText, `\x`) { + decoded, err := hex.DecodeString(strings.TrimPrefix(pgText, `\x`)) + if err == nil && sqliteText == string(decoded) { + return true + } + } + return false +} + +func decodeByteaSignature(s string) ([]byte, bool) { + s = strings.TrimPrefix(s, `\x`) + b, err := hex.DecodeString(s) + if err != nil { + return nil, false + } + return b, true +} + +func isBlobColumn(col ColumnSchema) bool { + return strings.Contains(strings.ToUpper(col.Type), "BLOB") +} + +func timestampValuesEqual(a, b string) bool { + if a == b { + return true + } + if a == "" || b == "" { + return false + } + return normalizeTimestamp(a) == normalizeTimestamp(b) +} + +func normalizeTimestamp(s string) string { + s = strings.TrimSpace(s) + s = strings.Replace(s, " ", "T", 1) + if strings.HasSuffix(s, "Z") { + return s + } + if idx := strings.LastIndexAny(s, "+-"); idx > 10 { + return s + } + if strings.Contains(s, "T") { + return s + "Z" + } + return s +} + +func sqliteRowSignature(ctx context.Context, sqlitePath string, table TableSchema, pkCol, pkVal string, coerceCtx *TypeCoercionContext) (string, error) { + cols := make([]string, 0, len(table.Columns)) + for _, col := range table.Columns { + cols = append(cols, sqliteSignatureColumnExpr(col, table, coerceCtx)) + } + query := fmt.Sprintf( + `SELECT %s FROM %q 
WHERE %q = %s LIMIT 1;`, + strings.Join(cols, " || '|' || "), + table.Name, + pkCol, + sqliteLiteral(pkVal), + ) + sqlite3, err := FindSQLite3() + if err != nil { + return "", err + } + out, err := runSQLiteQuery(ctx, sqlite3, sqlitePath, query) + if err != nil { + return "", err + } + return strings.TrimSpace(string(out)), nil +} + +func postgresRowSignature(ctx context.Context, db *sql.DB, table TableSchema, pkCol, pkVal string, all []TableSchema, coerceCtx *TypeCoercionContext) (string, error) { + cols := make([]string, 0, len(table.Columns)) + for _, col := range table.Columns { + cols = append(cols, postgresSignatureColumnExpr(col, table, all, coerceCtx)) + } + query := fmt.Sprintf( + `SELECT %s FROM %s WHERE %s = $1 LIMIT 1`, + strings.Join(cols, " || '|' || "), + postgres.QuoteIdentifier(table.Name), + postgres.QuoteIdentifier(pkCol), + ) + var sig sql.NullString + if err := db.QueryRowContext(ctx, query, pkVal).Scan(&sig); err != nil { + return "", err + } + if !sig.Valid { + return "", fmt.Errorf("row not found in %s where %s = %s", table.Name, pkCol, pkVal) + } + return sig.String, nil +} + +func sqliteLiteral(val string) string { + if val != "" && isSQLiteIntegerLiteral(val) { + return val + } + return "'" + strings.ReplaceAll(val, "'", "''") + "'" +} + +func isSQLiteIntegerLiteral(val string) bool { + if val == "" || val[0] == '-' { + if len(val) <= 1 { + return false + } + val = val[1:] + } + for i := 0; i < len(val); i++ { + if val[i] < '0' || val[i] > '9' { + return false + } + } + return true +} + +func runSQLiteQuery(ctx context.Context, sqlite3, sqlitePath, query string) ([]byte, error) { + cmd := execabs.CommandContext(ctx, sqlite3, sqlitePath, query) + out, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("%w: %s", err, strings.TrimSpace(string(out))) + } + return out, nil +} diff --git a/internal/import/d1/verify_checks_test.go b/internal/import/d1/verify_checks_test.go new file mode 100644 index 000000000..e6f15d0e1 --- 
/dev/null +++ b/internal/import/d1/verify_checks_test.go @@ -0,0 +1,221 @@ +package d1 + +import ( + "encoding/hex" + "strings" + "testing" +) + +func TestVerifyRowCounts(t *testing.T) { + source := map[string]int64{"users": 2, "posts": 2} + dest := map[string]int64{"users": 2, "posts": 1} + + results, ok := verifyRowCounts([]string{"users", "posts"}, source, dest) + if ok { + t.Fatal("expected mismatch") + } + if len(results) != 2 { + t.Fatalf("expected 2 table results, got %d", len(results)) + } + if !results[0].Match || results[1].Match { + t.Fatalf("unexpected match flags: %+v", results) + } +} + +func TestVerifyRowCountsIncludesImportScopedDestTables(t *testing.T) { + source := map[string]int64{"users": 1} + dest := map[string]int64{"users": 1, "legacy_import_table": 5} + results, ok := verifyRowCounts([]string{"users", "legacy_import_table"}, source, dest) + if ok { + t.Fatal("expected mismatch when import-scoped dest table has rows but source does not") + } + found := false + for _, r := range results { + if r.Table == "legacy_import_table" && !r.Match && r.DestRows == 5 && r.SourceRows == 0 { + found = true + } + } + if !found { + t.Fatalf("expected import-scoped dest table mismatch, got %+v", results) + } +} + +func TestColumnReferencesUUIDKey(t *testing.T) { + tables, err := ParseDump(testFixture(t)) + if err != nil { + t.Fatalf("ParseDump: %v", err) + } + coerceCtx, err := BuildTypeCoercionContext(testFixture(t), tables) + if err != nil { + t.Fatalf("BuildTypeCoercionContext: %v", err) + } + + var entityLinks TableSchema + for _, table := range tables { + if table.Name == "entity_links" { + entityLinks = table + break + } + } + if entityLinks.Name == "" { + t.Fatal("missing entity_links table") + } + + var entityID ColumnSchema + for _, col := range entityLinks.Columns { + if col.Name == "entity_id" { + entityID = col + break + } + } + if entityID.Name == "" { + t.Fatal("missing entity_id column") + } + if !columnReferencesUUIDKey(entityID, entityLinks, 
tables, coerceCtx) { + t.Fatal("expected entity_id to reference UUID primary key") + } + if isExplicitUUIDColumn(entityID) { + t.Fatal("entity_id should not be treated as explicit UUID column") + } +} + +func TestColumnReferencesUUIDKeyCycle(t *testing.T) { + tables := []TableSchema{ + { + Name: "nodes_a", + Columns: []ColumnSchema{{ + Name: "next_id", + Type: "TEXT", + ForeignKey: `REFERENCES nodes_b(id)`, + }}, + }, + { + Name: "nodes_b", + Columns: []ColumnSchema{{ + Name: "next_id", + Type: "TEXT", + ForeignKey: `REFERENCES nodes_a(id)`, + }}, + }, + } + if columnReferencesUUIDKey(tables[0].Columns[0], tables[0], tables, nil) { + t.Fatal("expected cyclic FK chain to resolve as non-UUID without stack overflow") + } +} + +func TestLooksLikeRailsSchemaMigrations(t *testing.T) { + rails := TableSchema{ + Name: "schema_migrations", + Columns: []ColumnSchema{{ + Name: "version", + Type: "VARCHAR(255)", + }}, + } + if !looksLikeRailsSchemaMigrations(rails) { + t.Fatal("expected rails-like schema_migrations") + } + + appTable := TableSchema{ + Name: "schema_migrations", + Columns: []ColumnSchema{ + {Name: "id", Type: "INTEGER", PrimaryKey: true}, + {Name: "name", Type: "TEXT"}, + }, + } + if looksLikeRailsSchemaMigrations(appTable) { + t.Fatal("expected app schema_migrations to differ from rails layout") + } +} + +func TestParseSQLiteCLIFields(t *testing.T) { + got := parseSQLiteCLIFields([]byte("120|0|0\n")) + if len(got) != 3 || got[0] != "120" || got[1] != "0" || got[2] != "0" { + t.Fatalf("parseSQLiteCLIFields() = %v", got) + } + got = parseSQLiteCLIFields([]byte("94400 123456\n")) + if len(got) != 2 || got[0] != "94400" { + t.Fatalf("parseSQLiteCLIFields() = %v", got) + } +} + +func TestJSONValuesEqual(t *testing.T) { + a := `{"priority": 0, "labels": ["seed"]}` + b := `{"labels": ["seed"], "priority": 0}` + if !jsonValuesEqual(a, b) { + t.Fatal("expected equivalent JSON objects to match") + } + if jsonValuesEqual(a, `{"priority": 1}`) { + t.Fatal("expected 
different JSON objects to mismatch") + } +} + +func TestByteaValuesEqual(t *testing.T) { + text := "attachment-1-payload" + hex := `\x` + hex.EncodeToString([]byte(text)) + if !byteaValuesEqual(text, hex) { + t.Fatalf("expected bytea hex %q to match text %q", hex, text) + } +} + +func TestByteaSignatureExprsUseHex(t *testing.T) { + col := ColumnSchema{Name: "payload", Type: "BLOB"} + table := TableSchema{Name: "attachments", Columns: []ColumnSchema{col}} + + sqliteExpr := sqliteSignatureColumnExpr(col, table, nil) + if !strings.Contains(sqliteExpr, "hex(") { + t.Fatalf("sqlite blob signature should use hex(), got %q", sqliteExpr) + } + + pgExpr := postgresSignatureColumnExpr(col, table, nil, nil) + if !strings.Contains(pgExpr, "encode(") || !strings.Contains(pgExpr, "'hex'") { + t.Fatalf("postgres bytea signature should use encode(..., 'hex'), got %q", pgExpr) + } +} + +func TestTimestampValuesEqual(t *testing.T) { + if !timestampValuesEqual("2024-01-15 12:00:00", "2024-01-15T12:00:00Z") { + t.Fatal("expected space and ISO timestamp forms to match") + } + if timestampValuesEqual("2024-01-15 12:00:00", "2024-01-16 12:00:00") { + t.Fatal("expected different timestamps to mismatch") + } +} + +func TestByteaValuesEqualBinaryHex(t *testing.T) { + raw := string([]byte{0x00, 0xff, 0xfe, 0x01}) + hexSig := hex.EncodeToString([]byte(raw)) + if !byteaValuesEqual(hexSig, hexSig) { + t.Fatalf("expected matching hex signatures for binary blob") + } + if !byteaValuesEqual(strings.ToUpper(hexSig), strings.ToLower(hexSig)) { + t.Fatalf("expected hex signatures to match regardless of case") + } +} + +func TestSummarizeRowSignatureForOutputOmitsBlobPayload(t *testing.T) { + table := TableSchema{ + Name: "attachments", + Columns: []ColumnSchema{ + {Name: "id", Type: "INTEGER"}, + {Name: "task_id", Type: "INTEGER"}, + {Name: "filename", Type: "TEXT"}, + {Name: "mime_type", Type: "TEXT"}, + {Name: "size_bytes", Type: "INTEGER"}, + {Name: "checksum", Type: "TEXT"}, + {Name: "payload", 
Type: "BLOB"}, + }, + } + longHex := strings.Repeat("ab", 200) + sig := strings.Join([]string{"3", "3", "file-3.bin", "application/octet-stream", "47952", "sha256:00000003", longHex}, "|") + + got := summarizeRowSignatureForOutput(sig, table) + if strings.Contains(got, longHex) { + t.Fatalf("expected blob hex to be truncated, got %q", got) + } + if !strings.Contains(got, "file-3.bin") { + t.Fatalf("expected non-blob fields preserved, got %q", got) + } + if !strings.Contains(got, "(200 bytes)") { + t.Fatalf("expected blob byte count in summary, got %q", got) + } +} diff --git a/internal/import/d1/verify_test.go b/internal/import/d1/verify_test.go new file mode 100644 index 000000000..6b9b639b8 --- /dev/null +++ b/internal/import/d1/verify_test.go @@ -0,0 +1,116 @@ +package d1 + +import ( + "context" + "strings" + "testing" +) + +func TestVerifyFailsWhenSourceCountsUnavailable(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + _, err := Verify(context.Background(), VerifyOptions{ + DestURI: "postgresql://u:p@localhost:5432/postgres?sslmode=disable", + InputPath: testFixture(t), + SQLitePath: "/nonexistent/staging.sqlite", + MigrationID: "verify001", + }) + if err == nil { + t.Fatal("expected error for missing sqlite staging db") + } + migrationErr, ok := err.(*MigrationError) + if !ok { + t.Fatalf("expected MigrationError, got %T: %v", err, err) + } + if migrationErr.Info.Code != ErrCodeVerifyFailed { + t.Fatalf("code = %q, want %q", migrationErr.Info.Code, ErrCodeVerifyFailed) + } + if !strings.Contains(migrationErr.Info.Message, "count source rows") { + t.Fatalf("message = %q", migrationErr.Info.Message) + } +} + +func TestVerifyUsesDBNameFromState(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + org, database, branch := "acme", "mydb", "main" + migrationID := "verify002" + if err := SavePlan(&PlanResult{ + MigrationID: migrationID, + Org: org, + Database: database, + Branch: branch, + InputPath: testFixture(t), + }); err != nil { + t.Fatalf("SavePlan: 
%v", err) + } + if err := updateMigrationState(org, database, branch, migrationID, func(state *MigrationState) { + state.DBName = "customdb" + }); err != nil { + t.Fatalf("update state: %v", err) + } + + opts := VerifyOptions{ + Org: org, + Database: database, + Branch: branch, + MigrationID: migrationID, + DestURI: "postgresql://u:p@localhost:5432/postgres?sslmode=disable", + InputPath: testFixture(t), + SQLitePath: "/nonexistent/staging.sqlite", + } + // Exercise DBName resolution before sqlite count fails. + dbName := opts.DBName + if dbName == "" && opts.MigrationID != "" { + if state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID); err == nil && state.DBName != "" { + dbName = state.DBName + } + } + if dbName != "customdb" { + t.Fatalf("resolved db_name = %q, want customdb", dbName) + } +} + +func TestResolveVerifySQLitePathDefaultsFromInput(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + org, database, branch := "acme", "mydb", "main" + migrationID := "verify003" + input := testFixture(t) + if err := SavePlan(&PlanResult{ + MigrationID: migrationID, + Org: org, + Database: database, + Branch: branch, + InputPath: input, + }); err != nil { + t.Fatalf("SavePlan: %v", err) + } + + _, sqlitePath, err := resolveVerifySQLitePath(VerifyOptions{ + Org: org, + Database: database, + Branch: branch, + MigrationID: migrationID, + }) + if err != nil { + t.Fatalf("resolveVerifySQLitePath: %v", err) + } + want := DefaultSQLitePath(input) + if sqlitePath != want { + t.Fatalf("sqlite path = %q, want %q", sqlitePath, want) + } +} + +func TestResolveVerifySQLitePathFailsOnBadMigrationIDWithInput(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + _, _, err := resolveVerifySQLitePath(VerifyOptions{ + Org: "acme", + Database: "mydb", + Branch: "main", + MigrationID: "missing-migration", + InputPath: testFixture(t), + }) + requireMigrationErr(t, err, ErrCodeNotFound) +} diff --git a/internal/postgres/postgres.go b/internal/postgres/postgres.go 
new file mode 100644 index 000000000..5ecd21fdf --- /dev/null +++ b/internal/postgres/postgres.go @@ -0,0 +1,336 @@ +// Package postgres provides PostgreSQL connection utilities. +package postgres + +import ( + "database/sql" + "fmt" + "net/url" + "strconv" + "strings" + "time" + + _ "github.com/jackc/pgx/v5/stdlib" +) + +type Config struct { + Host string + Port int + User string + Password string + Database string + SSLMode string + Options map[string]string +} + +// ParseConnectionURI supports both URI and keyword/value formats. +func ParseConnectionURI(uri string) (*Config, error) { + // Handle postgresql:// or postgres:// URIs + if strings.HasPrefix(uri, "postgresql://") || strings.HasPrefix(uri, "postgres://") { + return parseURIFormat(uri) + } + + // Handle keyword/value format (host=localhost port=5432 ...) + return parseKeyValueFormat(uri) +} + +func parseURIFormat(uri string) (*Config, error) { + u, err := url.Parse(uri) + if err != nil { + return nil, fmt.Errorf("invalid connection URI: %w", err) + } + + cfg := &Config{ + Host: u.Hostname(), + Port: 5432, + Options: make(map[string]string), + } + + if portStr := u.Port(); portStr != "" { + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("invalid port: %w", err) + } + cfg.Port = port + } + + if u.User != nil { + cfg.User = u.User.Username() + cfg.Password, _ = u.User.Password() + } + + // Database name from path without leading / + cfg.Database = strings.TrimPrefix(u.Path, "/") + + for key, values := range u.Query() { + if len(values) > 0 { + switch key { + case "sslmode": + cfg.SSLMode = values[0] + default: + cfg.Options[key] = values[0] + } + } + } + + if cfg.SSLMode == "" { + cfg.SSLMode = "require" + } + + return cfg, nil +} + +func parseKeyValueFormat(connStr string) (*Config, error) { + cfg := &Config{ + Port: 5432, + SSLMode: "require", + Options: make(map[string]string), + } + + for _, pair := range strings.Fields(connStr) { + parts := strings.SplitN(pair, "=", 2) + 
if len(parts) != 2 { + continue + } + key := parts[0] + value := strings.Trim(parts[1], "'\"") + + switch key { + case "host": + cfg.Host = value + case "port": + port, err := strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("invalid port: %w", err) + } + cfg.Port = port + case "user": + cfg.User = value + case "password": + cfg.Password = value + case "dbname": + cfg.Database = value + case "sslmode": + cfg.SSLMode = value + default: + cfg.Options[key] = value + } + } + + return cfg, nil +} + +func BuildConnectionString(cfg *Config) string { + var parts []string + + if cfg.Host != "" { + parts = append(parts, fmt.Sprintf("host=%s", cfg.Host)) + } + if cfg.Port != 0 { + parts = append(parts, fmt.Sprintf("port=%d", cfg.Port)) + } + if cfg.User != "" { + parts = append(parts, fmt.Sprintf("user=%s", cfg.User)) + } + if cfg.Password != "" { + parts = append(parts, fmt.Sprintf("password=%s", quoteValue(cfg.Password))) + } + if cfg.Database != "" { + parts = append(parts, fmt.Sprintf("dbname=%s", cfg.Database)) + } + if cfg.SSLMode != "" { + parts = append(parts, fmt.Sprintf("sslmode=%s", cfg.SSLMode)) + } + + for key, value := range cfg.Options { + parts = append(parts, fmt.Sprintf("%s=%s", key, quoteValue(value))) + } + + return strings.Join(parts, " ") +} + +// BuildConnectionURI returns a postgresql:// URI suitable for pgloader. 
+func BuildConnectionURI(cfg *Config) string { + host := cfg.Host + if cfg.Port != 0 { + host = fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + } + + u := &url.URL{ + Scheme: "postgresql", + Host: host, + Path: "/" + cfg.Database, + } + + if cfg.User != "" { + if cfg.Password != "" { + u.User = url.UserPassword(cfg.User, cfg.Password) + } else { + u.User = url.User(cfg.User) + } + } + + q := url.Values{} + if cfg.SSLMode != "" { + q.Set("sslmode", cfg.SSLMode) + } + for key, value := range cfg.Options { + q.Set(key, value) + } + u.RawQuery = q.Encode() + + return u.String() +} + +func quoteValue(s string) string { + if strings.ContainsAny(s, " '\"\\") { + return "'" + strings.ReplaceAll(s, "'", "\\'") + "'" + } + return s +} + +// OpenConnection opens a PostgreSQL connection with sensible defaults. +func OpenConnection(connStr string) (*sql.DB, error) { + db, err := sql.Open("pgx", connStr) + if err != nil { + return nil, fmt.Errorf("failed to open connection: %w", err) + } + + db.SetMaxOpenConns(5) + db.SetMaxIdleConns(2) + db.SetConnMaxLifetime(5 * time.Minute) + + return db, nil +} + +// QuoteIdentifier escapes a PostgreSQL identifier. 
+func QuoteIdentifier(name string) string { + return `"` + strings.ReplaceAll(name, `"`, `""`) + `"` +} + +func RedactPassword(connStr string) string { + if strings.HasPrefix(connStr, "postgresql://") || strings.HasPrefix(connStr, "postgres://") { + u, err := url.Parse(connStr) + if err == nil && u.User != nil { + if _, hasPass := u.User.Password(); hasPass { + u.User = url.UserPassword(u.User.Username(), "****") + return u.String() + } + } + return connStr + } + + return redactKeywordPassword(connStr) +} + +func redactKeywordPassword(connStr string) string { + var parts []string + i := 0 + for i < len(connStr) { + for i < len(connStr) && (connStr[i] == ' ' || connStr[i] == '\t') { + i++ + } + if i >= len(connStr) { + break + } + keyStart := i + for i < len(connStr) && connStr[i] != '=' && connStr[i] != ' ' && connStr[i] != '\t' { + i++ + } + key := connStr[keyStart:i] + if i >= len(connStr) || connStr[i] != '=' { + parts = append(parts, strings.TrimSpace(connStr[keyStart:])) + break + } + i++ + if strings.EqualFold(key, "password") { + _, next := readPasswordConnValue(connStr, i) + i = next + parts = append(parts, key+"=****") + continue + } + val, next := readKeywordConnValue(connStr, i) + i = next + parts = append(parts, key+"="+val) + } + return strings.Join(parts, " ") +} + +var connParamKeywords = []string{ + "host", "hostaddr", "port", "user", "password", "dbname", "database", + "sslmode", "application_name", "connect_timeout", "options", + "fallback_application_name", "client_encoding", "target_session_attrs", + "replication", "gssencmode", "sslcert", "sslkey", "sslrootcert", + "requirepeer", "krbsrvname", "gsslib", "service", +} + +func readPasswordConnValue(connStr string, start int) (value string, next int) { + if start >= len(connStr) { + return "", start + } + if connStr[start] == '\'' || connStr[start] == '"' { + return readKeywordConnValue(connStr, start) + } + end := start + nextConnKeywordAssignIndex(connStr[start:]) + return 
strings.TrimSpace(connStr[start:end]), end +} + +func nextConnKeywordAssignIndex(s string) int { + if s == "" { + return 0 + } + lower := strings.ToLower(s) + best := len(s) + for _, kw := range connParamKeywords { + token := " " + kw + "=" + if idx := strings.Index(lower, token); idx >= 0 && idx < best { + best = idx + } + } + return best +} + +func readKeywordConnValue(connStr string, start int) (value string, next int) { + if start >= len(connStr) { + return "", start + } + switch connStr[start] { + case '\'': + var b strings.Builder + i := start + 1 + for i < len(connStr) { + if connStr[i] == '\'' { + if i+1 < len(connStr) && connStr[i+1] == '\'' { + b.WriteByte('\'') + i += 2 + continue + } + return b.String(), i + 1 + } + b.WriteByte(connStr[i]) + i++ + } + return b.String(), len(connStr) + case '"': + var b strings.Builder + i := start + 1 + for i < len(connStr) { + if connStr[i] == '"' { + if i+1 < len(connStr) && connStr[i+1] == '"' { + b.WriteByte('"') + i += 2 + continue + } + return b.String(), i + 1 + } + b.WriteByte(connStr[i]) + i++ + } + return b.String(), len(connStr) + default: + i := start + for i < len(connStr) && connStr[i] != ' ' && connStr[i] != '\t' { + i++ + } + return connStr[start:i], i + } +} diff --git a/internal/postgres/postgres_test.go b/internal/postgres/postgres_test.go new file mode 100644 index 000000000..b5a6128b9 --- /dev/null +++ b/internal/postgres/postgres_test.go @@ -0,0 +1,220 @@ +package postgres + +import ( + "testing" +) + +func TestParseConnectionURI(t *testing.T) { + tests := []struct { + name string + uri string + want *Config + wantErr bool + }{ + { + name: "basic uri", + uri: "postgresql://user:pass@localhost:5432/mydb", + want: &Config{ + Host: "localhost", + Port: 5432, + User: "user", + Password: "pass", + Database: "mydb", + SSLMode: "require", + Options: make(map[string]string), + }, + }, + { + name: "uri with sslmode", + uri: "postgresql://user:pass@localhost:5432/mydb?sslmode=disable", + want: &Config{ + 
Host: "localhost", + Port: 5432, + User: "user", + Password: "pass", + Database: "mydb", + SSLMode: "disable", + Options: make(map[string]string), + }, + }, + { + name: "uri without password", + uri: "postgresql://user@localhost:5432/mydb", + want: &Config{ + Host: "localhost", + Port: 5432, + User: "user", + Password: "", + Database: "mydb", + SSLMode: "require", + Options: make(map[string]string), + }, + }, + { + name: "key-value format", + uri: "host=localhost port=5432 dbname=mydb", + want: &Config{ + Host: "localhost", + Port: 5432, + Database: "mydb", + SSLMode: "require", + Options: make(map[string]string), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseConnectionURI(tt.uri) + if (err != nil) != tt.wantErr { + t.Errorf("ParseConnectionURI() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if got.Host != tt.want.Host { + t.Errorf("Host = %v, want %v", got.Host, tt.want.Host) + } + if got.Port != tt.want.Port { + t.Errorf("Port = %v, want %v", got.Port, tt.want.Port) + } + if got.User != tt.want.User { + t.Errorf("User = %v, want %v", got.User, tt.want.User) + } + if got.Password != tt.want.Password { + t.Errorf("Password = %v, want %v", got.Password, tt.want.Password) + } + if got.Database != tt.want.Database { + t.Errorf("Database = %v, want %v", got.Database, tt.want.Database) + } + if got.SSLMode != tt.want.SSLMode { + t.Errorf("SSLMode = %v, want %v", got.SSLMode, tt.want.SSLMode) + } + }) + } +} + +func TestBuildConnectionString(t *testing.T) { + tests := []struct { + name string + cfg *Config + want string + }{ + { + name: "basic config", + cfg: &Config{ + Host: "localhost", + Port: 5432, + User: "user", + Password: "pass", + Database: "mydb", + SSLMode: "require", + Options: make(map[string]string), + }, + want: "host=localhost port=5432 user=user password=pass dbname=mydb sslmode=require", + }, + { + name: "config without password", + cfg: &Config{ + Host: 
"localhost", + Port: 5432, + User: "user", + Database: "mydb", + SSLMode: "disable", + Options: make(map[string]string), + }, + want: "host=localhost port=5432 user=user dbname=mydb sslmode=disable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := BuildConnectionString(tt.cfg) + if got != tt.want { + t.Errorf("BuildConnectionString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRedactPassword(t *testing.T) { + tests := []struct { + name string + connStr string + want string + }{ + { + name: "with password", + connStr: "host=localhost port=5432 user=user password=secret dbname=mydb", + want: "host=localhost port=5432 user=user password=**** dbname=mydb", + }, + { + name: "quoted password with spaces", + connStr: "host=localhost user=user password='my secret' dbname=mydb", + want: "host=localhost user=user password=**** dbname=mydb", + }, + { + name: "unquoted password with spaces", + connStr: "host=localhost user=user password=my secret dbname=mydb", + want: "host=localhost user=user password=**** dbname=mydb", + }, + { + name: "without password", + connStr: "host=localhost port=5432 user=user dbname=mydb", + want: "host=localhost port=5432 user=user dbname=mydb", + }, + { + name: "empty string", + connStr: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := RedactPassword(tt.connStr) + if got != tt.want { + t.Errorf("RedactPassword() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestQuoteIdentifier(t *testing.T) { + tests := []struct { + name string + id string + want string + }{ + { + name: "simple identifier", + id: "mytable", + want: `"mytable"`, + }, + { + name: "identifier with quotes", + id: `table"name`, + want: `"table""name"`, + }, + { + name: "identifier with multiple quotes", + id: `my"table"name`, + want: `"my""table""name"`, + }, + { + name: "empty string", + id: "", + want: `""`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + got := QuoteIdentifier(tt.id) + if got != tt.want { + t.Errorf("QuoteIdentifier() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/postgres/psql.go b/internal/postgres/psql.go new file mode 100644 index 000000000..3149841be --- /dev/null +++ b/internal/postgres/psql.go @@ -0,0 +1,67 @@ +package postgres + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + exec "golang.org/x/sys/execabs" +) + +var psqlVersionRegex = regexp.MustCompile(`psql \(PostgreSQL\) (\d+)\.?(\d*)`) + +// FindPsqlPath locates a PostgreSQL psql client on PATH. +func FindPsqlPath() (string, error) { + for _, cmd := range []string{"psql-18", "psql-17", "psql-16", "psql-15", "psql"} { + path, err := exec.LookPath(cmd) + if err != nil { + continue + } + c := exec.Command(path, "--version") + out, err := c.Output() + if err != nil { + continue + } + if strings.Contains(string(out), "PostgreSQL") { + return path, nil + } + } + + return "", fmt.Errorf("couldn't find the 'psql' command-line tool required for PostgreSQL imports.\n" + + "To install, run: brew install postgresql@18") +} + +// CheckPsqlVersion verifies psql meets a minimum major version. 
+func CheckPsqlVersion(minMajor int) (major, minor int, err error) { + path, err := FindPsqlPath() + if err != nil { + return 0, 0, err + } + + c := exec.Command(path, "--version") + out, err := c.Output() + if err != nil { + return 0, 0, fmt.Errorf("failed to get psql version: %w", err) + } + + matches := psqlVersionRegex.FindStringSubmatch(string(out)) + if len(matches) < 2 { + return 0, 0, fmt.Errorf("could not parse psql version from: %s", string(out)) + } + + major, err = strconv.Atoi(matches[1]) + if err != nil { + return 0, 0, fmt.Errorf("could not parse psql major version: %w", err) + } + + if len(matches) > 2 && matches[2] != "" { + minor, _ = strconv.Atoi(matches[2]) + } + + if major < minMajor { + return major, minor, fmt.Errorf("psql version %d.%d is too old, minimum required is %d", major, minor, minMajor) + } + + return major, minor, nil +} diff --git a/internal/printer/printer.go b/internal/printer/printer.go index fdbc1749d..51d7b7e0c 100644 --- a/internal/printer/printer.go +++ b/internal/printer/printer.go @@ -125,23 +125,57 @@ func (p *Printer) out() io.Writer { // function needs to be called in a defer or when it's decided to stop the // spinner func (p *Printer) PrintProgress(message string) func() { + handle := p.StartProgress(message) + return handle.Stop +} + +// ProgressHandle is an updatable progress indicator (spinner on TTY). +type ProgressHandle struct { + update func(string) + stop func() +} + +// Update changes the progress message. +func (h *ProgressHandle) Update(message string) { + if h != nil && h.update != nil { + h.update(message) + } +} + +// Stop ends the progress indicator. +func (h *ProgressHandle) Stop() { + if h != nil && h.stop != nil { + h.stop() + } +} + +// StartProgress starts a spinner or line-based progress on w when not a TTY. 
+func (p *Printer) StartProgress(message string) *ProgressHandle { + return p.startProgressOn(p.out(), message) +} + +func (p *Printer) startProgressOn(w io.Writer, message string) *ProgressHandle { if !IsTTY { - fmt.Fprintln(p.out(), message) - return func() {} + fmt.Fprintln(w, message) + return &ProgressHandle{ + update: func(msg string) { fmt.Fprintln(w, msg) }, + } } - s := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(p.out())) + s := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(w)) s.Suffix = fmt.Sprintf(" %s", message) _ = s.Color("bold", "green") s.Start() - return func() { - s.Stop() - - // NOTE(fatih) the spinner library doesn't clear the line properly, - // hence remove it ourselves. This line should be removed once it's - // fixed in upstream. https://github.com/briandowns/spinner/pull/117 - fmt.Fprint(p.out(), "\r\033[2K") + return &ProgressHandle{ + update: func(msg string) { s.Suffix = fmt.Sprintf(" %s", msg) }, + stop: func() { + s.Stop() + // NOTE(fatih) the spinner library doesn't clear the line properly, + // hence remove it ourselves. This line should be removed once it's + // fixed in upstream. https://github.com/briandowns/spinner/pull/117 + fmt.Fprint(w, "\r\033[2K") + }, } }