Skip to content

Commit d944c24

Browse files
njskalskiAndrzej J Skalski
andauthored
fix(langserver): support go-to-definition for plugin-defined rules (#3491)
This is a follow up MR from #3485 . As suggested by @toastwaffle I split the MR into two smaller pieces. This one is about go-to-definition in plugin rules. I applied every single requested change. Furthermore, I removed all formatting changes my IDE did, to keep the patch short. --------- Co-authored-by: Andrzej J Skalski <gitstuff@s5i.ch>
1 parent e128a94 commit d944c24

9 files changed

Lines changed: 147 additions & 23 deletions

File tree

src/cmap/cerrmap.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,14 @@ func (m *ErrMap[K, V]) GetOrSet(key K, f func() (V, error)) (V, error) {
7878
}
7979
return v.Val, v.Err
8080
}
81+
82+
// Range calls f for each key-value pair in the map.
83+
// No particular consistency guarantees are made during iteration.
84+
func (m *ErrMap[K, V]) Range(f func(key K, val V)) {
85+
m.m.Range(func(key K, val errV[V]) {
86+
if val.Err != nil {
87+
return // skip errors
88+
}
89+
f(key, val.Val)
90+
})
91+
}

src/cmap/cmap.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ func (m *Map[K, V]) Values() []V {
9494
return ret
9595
}
9696

97+
// Range calls f for each key-value pair in the map.
98+
// No particular consistency guarantees are made during iteration.
99+
func (m *Map[K, V]) Range(f func(key K, val V)) {
100+
for i := 0; i < len(m.shards); i++ {
101+
m.shards[i].Range(f)
102+
}
103+
}
104+
97105
// An awaitableValue represents a value in the map & an awaitable channel for it to exist.
98106
type awaitableValue[V any] struct {
99107
Val V
@@ -195,3 +203,14 @@ func (s *shard[K, V]) Contains(key K) bool {
195203
_, ok := s.m[key]
196204
return ok
197205
}
206+
207+
// Range calls f for each key-value pair in this shard.
208+
func (s *shard[K, V]) Range(f func(key K, val V)) {
209+
s.l.RLock()
210+
defer s.l.RUnlock()
211+
for k, v := range s.m {
212+
if v.Wait == nil { // Only include completed values
213+
f(k, v.Val)
214+
}
215+
}
216+
}

src/parse/asp/parser.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,24 @@ func (p *Parser) optimiseBuiltinCalls(stmts []*Statement) {
257257
}
258258
}
259259

260+
// AllFunctionsByFile returns all function definitions grouped by filename.
261+
// This includes functions from builtins, plugins, and subincludes.
262+
// It iterates over the ASTs stored by the interpreter.
263+
func (p *Parser) AllFunctionsByFile() map[string][]*Statement {
264+
if p.interpreter == nil || p.interpreter.asts == nil {
265+
return nil
266+
}
267+
result := make(map[string][]*Statement)
268+
p.interpreter.asts.Range(func(filename string, stmts []*Statement) {
269+
for _, stmt := range stmts {
270+
if stmt.FuncDef != nil {
271+
result[filename] = append(result[filename], stmt)
272+
}
273+
}
274+
})
275+
return result
276+
}
277+
260278
// whitelistedKwargs returns true if the given built-in function name is allowed to
261279
// be called as non-kwargs.
262280
// TODO(peterebden): Come up with a syntax that exposes this directly in the file.

src/parse/init.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,19 @@ func InitParser(state *core.BuildState) *core.BuildState {
2525
return state
2626
}
2727

28+
// GetAspParser returns the underlying asp.Parser from the state's parser.
29+
// This is useful for tools like the language server that need direct access to AST information.
30+
// Returns nil if the state's parser is not set or is not an aspParser.
31+
func GetAspParser(state *core.BuildState) *asp.Parser {
32+
if state.Parser == nil {
33+
return nil
34+
}
35+
if ap, ok := state.Parser.(*aspParser); ok {
36+
return ap.parser
37+
}
38+
return nil
39+
}
40+
2841
// aspParser implements the core.Parser interface around our parser package.
2942
type aspParser struct {
3043
parser *asp.Parser

tools/build_langserver/lsp/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ go_library(
1717
"//rules",
1818
"//src/core",
1919
"//src/fs",
20+
"//src/parse",
2021
"//src/parse/asp",
2122
"//src/plz",
2223
"//tools/build_langserver/lsp/astutils",

tools/build_langserver/lsp/completion.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ func (h *Handler) completeString(doc *doc, s string, line, col int) (*lsp.Comple
109109
// completeIdent completes an arbitrary identifier
110110
func (h *Handler) completeIdent(doc *doc, s string, line, col int) (*lsp.CompletionList, error) {
111111
list := &lsp.CompletionList{}
112-
for name, f := range h.builtins {
113-
if strings.HasPrefix(name, s) {
112+
for name, builtins := range h.builtins {
113+
if strings.HasPrefix(name, s) && len(builtins) > 0 {
114114
item := completionItem(name, "", line, col)
115-
item.Documentation = f.Stmt.FuncDef.Docstring
115+
item.Documentation = builtins[0].Stmt.FuncDef.Docstring
116116
item.Kind = lsp.CIKFunction
117117
list.Items = append(list.Items, item)
118118
}

tools/build_langserver/lsp/definition.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,41 +18,40 @@ func (h *Handler) definition(params *lsp.TextDocumentPositionParams) ([]lsp.Loca
1818
ast := h.parseIfNeeded(doc)
1919
f := doc.AspFile()
2020

21-
var locs []lsp.Location
21+
locs := []lsp.Location{}
2222
pos := aspPos(params.Position)
2323
asp.WalkAST(ast, func(expr *asp.Expression) bool {
24-
if !asp.WithinRange(pos, f.Pos(expr.Pos), f.Pos(expr.EndPos)) {
24+
exprStart := f.Pos(expr.Pos)
25+
exprEnd := f.Pos(expr.EndPos)
26+
if !asp.WithinRange(pos, exprStart, exprEnd) {
2527
return false
2628
}
27-
2829
if expr.Val.Ident != nil {
2930
if loc := h.findGlobal(expr.Val.Ident.Name); loc.URI != "" {
3031
locs = append(locs, loc)
3132
}
3233
return false
3334
}
34-
3535
if expr.Val.String != "" {
3636
label := astutils.TrimStrLit(expr.Val.String)
3737
if loc := h.findLabel(doc.PkgName, label); loc.URI != "" {
3838
locs = append(locs, loc)
3939
}
4040
return false
4141
}
42-
4342
return true
4443
})
45-
// It might also be a statement.
44+
// It might also be a statement (e.g. a function call like go_library(...))
4645
asp.WalkAST(ast, func(stmt *asp.Statement) bool {
4746
if stmt.Ident != nil {
48-
endPos := f.Pos(stmt.Pos)
47+
stmtStart := f.Pos(stmt.Pos)
48+
endPos := stmtStart
4949
// TODO(jpoole): The AST should probably just have this information
5050
endPos.Column += len(stmt.Ident.Name)
5151

52-
if !asp.WithinRange(pos, f.Pos(stmt.Pos), endPos) {
53-
return false
52+
if !asp.WithinRange(pos, stmtStart, endPos) {
53+
return true // continue to other statements
5454
}
55-
5655
if loc := h.findGlobal(stmt.Ident.Name); loc.URI != "" {
5756
locs = append(locs, loc)
5857
}
@@ -78,6 +77,9 @@ func (h *Handler) findLabel(currentPath, label string) lsp.Location {
7877
}
7978

8079
pkg := h.state.Graph.PackageByLabel(l)
80+
if pkg == nil {
81+
return lsp.Location{}
82+
}
8183
uri := lsp.DocumentURI("file://" + filepath.Join(h.root, pkg.Filename))
8284
loc := lsp.Location{URI: uri}
8385
doc, err := h.maybeOpenDoc(uri)
@@ -137,9 +139,18 @@ func findName(args []asp.CallArgument) string {
137139

138140
// findGlobal returns the location of a global of the given name.
139141
func (h *Handler) findGlobal(name string) lsp.Location {
140-
if f, present := h.builtins[name]; present {
142+
h.mutex.Lock()
143+
builtins := h.builtins[name]
144+
h.mutex.Unlock()
145+
if len(builtins) > 0 {
146+
f := builtins[0]
147+
filename := f.Pos.Filename
148+
// Make path absolute if it's relative
149+
if !filepath.IsAbs(filename) {
150+
filename = filepath.Join(h.root, filename)
151+
}
141152
return lsp.Location{
142-
URI: lsp.DocumentURI("file://" + f.Pos.Filename),
153+
URI: lsp.DocumentURI("file://" + filename),
143154
Range: rng(f.Pos, f.EndPos),
144155
}
145156
}

tools/build_langserver/lsp/lsp.go

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/thought-machine/please/rules"
2121
"github.com/thought-machine/please/src/core"
2222
"github.com/thought-machine/please/src/fs"
23+
"github.com/thought-machine/please/src/parse"
2324
"github.com/thought-machine/please/src/parse/asp"
2425
"github.com/thought-machine/please/src/plz"
2526
)
@@ -33,7 +34,7 @@ type Handler struct {
3334
mutex sync.Mutex // guards docs
3435
state *core.BuildState
3536
parser *asp.Parser
36-
builtins map[string]builtin
37+
builtins map[string][]builtin
3738
pkgs *pkg
3839
root string
3940
}
@@ -55,7 +56,7 @@ func NewHandler() *Handler {
5556
return &Handler{
5657
docs: map[string]*doc{},
5758
pkgs: &pkg{},
58-
builtins: map[string]builtin{},
59+
builtins: map[string][]builtin{},
5960
}
6061
}
6162

@@ -195,13 +196,35 @@ func (h *Handler) initialize(params *lsp.InitializeParams) (*lsp.InitializeResul
195196
}
196197
h.state = core.NewBuildState(config)
197198
h.state.NeedBuild = false
198-
// We need an unwrapped parser instance as well for raw access.
199-
h.parser = asp.NewParser(h.state)
199+
// Initialize the parser on state first, so that plz.RunHost uses the same parser.
200+
// This ensures plugin subincludes are stored in the same AST cache we use.
201+
parse.InitParser(h.state)
202+
h.parser = parse.GetAspParser(h.state)
203+
if h.parser == nil {
204+
return nil, fmt.Errorf("failed to get asp parser from state")
205+
}
200206
// Parse everything in the repo up front.
201207
// This is a lot easier than trying to do clever partial parses later on, although
202208
// eventually we may want that if we start dealing with truly large repos.
203209
go func() {
210+
// Start a goroutine to periodically load parser functions as they become available.
211+
// This allows go-to-definition to work progressively while the full parse runs.
212+
done := make(chan struct{})
213+
go func() {
214+
ticker := time.NewTicker(2 * time.Second)
215+
defer ticker.Stop()
216+
for {
217+
select {
218+
case <-done:
219+
h.loadParserFunctions()
220+
return
221+
case <-ticker.C:
222+
h.loadParserFunctions()
223+
}
224+
}
225+
}()
204226
plz.RunHost(core.WholeGraph, h.state)
227+
close(done)
205228
log.Debug("initial parse complete")
206229
h.buildPackageTree()
207230
log.Debug("built completion package tree")
@@ -256,18 +279,46 @@ func (h *Handler) loadBuiltins() error {
256279
f := asp.NewFile(dest, data)
257280
for _, stmt := range stmts {
258281
if stmt.FuncDef != nil {
259-
h.builtins[stmt.FuncDef.Name] = builtin{
282+
h.builtins[stmt.FuncDef.Name] = append(h.builtins[stmt.FuncDef.Name], builtin{
260283
Stmt: stmt,
261284
Pos: f.Pos(stmt.Pos),
262285
EndPos: f.Pos(stmt.EndPos),
263-
}
286+
})
264287
}
265288
}
266289
}
267290
log.Debug("loaded builtin function information")
268291
return nil
269292
}
270293

294+
// loadParserFunctions loads function definitions from the parser's ASTs.
295+
// This includes plugin-defined functions like go_library, python_library, etc.
296+
func (h *Handler) loadParserFunctions() {
297+
funcsByFile := h.parser.AllFunctionsByFile()
298+
if funcsByFile == nil {
299+
return
300+
}
301+
h.mutex.Lock()
302+
defer h.mutex.Unlock()
303+
for filename, stmts := range funcsByFile {
304+
// Read the file to create a File object for position conversion
305+
data, err := os.ReadFile(filename)
306+
if err != nil {
307+
log.Warning("failed to read file %s: %v", filename, err)
308+
continue
309+
}
310+
file := asp.NewFile(filename, data)
311+
for _, stmt := range stmts {
312+
name := stmt.FuncDef.Name
313+
h.builtins[name] = append(h.builtins[name], builtin{
314+
Stmt: stmt,
315+
Pos: file.Pos(stmt.Pos),
316+
EndPos: file.Pos(stmt.EndPos),
317+
})
318+
}
319+
}
320+
}
321+
271322
// fromURI converts a DocumentURI to a path.
272323
func fromURI(uri lsp.DocumentURI) string {
273324
if !strings.HasPrefix(string(uri), "file://") {

tools/build_langserver/lsp/lsp_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ func TestCompletionFunction(t *testing.T) {
458458
Kind: lsp.CIKFunction,
459459
InsertTextFormat: lsp.ITFPlainText,
460460
TextEdit: textEdit("plugin_repo", 0, 4, 0),
461-
Documentation: h.builtins["plugin_repo"].Stmt.FuncDef.Docstring,
461+
Documentation: h.builtins["plugin_repo"][0].Stmt.FuncDef.Docstring,
462462
}},
463463
}, completions)
464464
}
@@ -492,7 +492,7 @@ func TestCompletionPartialFunction(t *testing.T) {
492492
Kind: lsp.CIKFunction,
493493
InsertTextFormat: lsp.ITFPlainText,
494494
TextEdit: textEdit("plugin_repo", 0, 9, 0),
495-
Documentation: h.builtins["plugin_repo"].Stmt.FuncDef.Docstring,
495+
Documentation: h.builtins["plugin_repo"][0].Stmt.FuncDef.Docstring,
496496
}},
497497
}, completions)
498498
}

0 commit comments

Comments
 (0)