Skip to content

Commit 628acbf

Browse files
authored
feat: Codegen-lsp v0 (#396)
1 parent 0400aa5 commit 628acbf

21 files changed

+1179
-19
lines changed

pyproject.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ types = [
115115
"types-requests>=2.32.0.20241016",
116116
"types-toml>=0.10.8.20240310",
117117
]
118+
lsp = ["pygls>=2.0.0a2", "lsprotocol==2024.0.0b1"]
118119
[tool.uv]
119120
cache-keys = [{ git = { commit = true, tags = true } }]
120121
dev-dependencies = [
@@ -149,11 +150,12 @@ dev-dependencies = [
149150
"isort>=5.13.2",
150151
"emoji>=2.14.0",
151152
"pytest-benchmark[histogram]>=5.1.0",
152-
"pytest-asyncio<1.0.0,>=0.21.1",
153+
"pytest-asyncio>=0.21.1,<1.0.0",
153154
"loguru>=0.7.3",
154155
"httpx<0.28.2,>=0.28.1",
155156
"jupyterlab>=4.3.5",
156157
"modal>=0.73.25",
158+
"pytest-lsp>=1.0.0b1",
157159
]
158160

159161

@@ -212,6 +214,8 @@ xfail_strict = true
212214
junit_duration_report = "call"
213215
junit_logging = "all"
214216
tmp_path_retention_policy = "failed"
217+
asyncio_mode = "auto"
218+
asyncio_default_fixture_loop_scope = "function"
215219
[build-system]
216220
requires = ["hatchling>=1.26.3", "hatch-vcs>=0.4.0", "setuptools-scm>=8.0.0"]
217221
build-backend = "hatchling.build"

src/codegen/extensions/lsp/completion.py

Whitespace-only changes.
+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import logging
2+
3+
from lsprotocol.types import Position
4+
5+
from codegen.sdk.core.assignment import Assignment
6+
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
7+
from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute
8+
from codegen.sdk.core.expressions.expression import Expression
9+
from codegen.sdk.core.expressions.name import Name
10+
from codegen.sdk.core.interfaces.editable import Editable
11+
from codegen.sdk.core.interfaces.has_name import HasName
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
def go_to_definition(node: Editable | None, uri: str, position: Position) -> Editable | None:
    """Resolve the node under the cursor to the node that names its definition.

    Args:
        node: Node found at the cursor, or ``None`` when nothing was found.
        uri: Document URI; only used in log messages.
        position: Cursor position; only used in log messages.

    Returns:
        The resolved node (its name node when the resolved value is a
        ``HasName``), or ``None`` when ``node`` is not an ``Expression`` or
        nothing could be resolved.
    """
    # ``isinstance`` is False for ``None``, so this also covers a missing node.
    if not isinstance(node, Expression):
        logger.warning(f"No node found at {uri}:{position}")
        return None

    # Cursor is on the attribute part of a chained attribute: use the whole chain.
    is_chain_attribute = isinstance(node, Name) and isinstance(node.parent, ChainedAttribute) and node.parent.attribute == node
    if is_chain_attribute:
        node = node.parent

    # Cursor is on the name of a call: resolve through the call itself.
    is_call_name = isinstance(node.parent, FunctionCall) and node.parent.get_name() == node
    if is_call_name:
        node = node.parent

    logger.info(f"Resolving definition for {node}")
    target = node.function_definition if isinstance(node, FunctionCall) else node.resolved_value
    if target is None:
        logger.warning(f"No resolved value found for {node.name} at {uri}:{position}")
        return None

    # Point at the name of a named definition rather than at its whole body.
    if isinstance(target, HasName):
        target = target.get_name()

    # When the target is the value side of an assignment, jump to the assigned name.
    if isinstance(target.parent, Assignment) and target.parent.value == target:
        target = target.parent.get_name()
    return target
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from lsprotocol.types import DocumentSymbol
2+
3+
from codegen.extensions.lsp.kind import get_kind
4+
from codegen.extensions.lsp.range import get_range
5+
from codegen.sdk.core.class_definition import Class
6+
from codegen.sdk.core.interfaces.editable import Editable
7+
from codegen.sdk.extensions.sort import sort_editables
8+
9+
10+
def get_document_symbol(node: Editable) -> DocumentSymbol:
11+
children = []
12+
nodes = []
13+
if isinstance(node, Class):
14+
nodes.extend(node.methods)
15+
nodes.extend(node.attributes)
16+
nodes.extend(node.nested_classes)
17+
nodes = sort_editables(nodes)
18+
for child in nodes:
19+
children.append(get_document_symbol(child))
20+
return DocumentSymbol(
21+
name=node.name,
22+
kind=get_kind(node),
23+
range=get_range(node),
24+
selection_range=get_range(node.get_name()),
25+
children=children,
26+
)

src/codegen/extensions/lsp/io.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import logging
2+
from pathlib import Path
3+
4+
from lsprotocol import types
5+
from lsprotocol.types import Position, Range, TextEdit
6+
from pygls.workspace import TextDocument, Workspace
7+
8+
from codegen.sdk.codebase.io.file_io import FileIO
9+
from codegen.sdk.codebase.io.io import IO
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class LSPIO(IO):
    """IO layer that reads through the LSP workspace and buffers writes as edits.

    Reads prefer, in order: an edit buffered by ``write_bytes``, the workspace
    document, then the wrapped ``FileIO``. ``write_bytes`` never touches disk;
    it records one ``TextEdit`` per document URI, and ``get_document_changes``
    hands those edits out and clears the buffer.
    """

    base_io: FileIO
    workspace: Workspace
    # Pending edits keyed by document URI. Deliberately has no class-level
    # default: a shared ``{}`` would leak edits between ``LSPIO`` instances.
    changes: dict[str, TextEdit]

    def __init__(self, workspace: Workspace):
        self.workspace = workspace
        self.base_io = FileIO()
        self.changes = {}

    def _get_doc(self, path: Path) -> TextDocument | None:
        """Look up the workspace document for ``path`` by its URI."""
        uri = path.as_uri()
        logger.info(f"Getting document for {uri}")
        return self.workspace.get_text_document(uri)

    def read_bytes(self, path: Path) -> bytes:
        """Return the UTF-8 content of ``path``, preferring unsent buffered edits."""
        if pending := self.changes.get(path.as_uri()):
            return pending.new_text.encode("utf-8")
        if doc := self._get_doc(path):
            return doc.source.encode("utf-8")
        return self.base_io.read_bytes(path)

    def write_bytes(self, path: Path, content: bytes) -> None:
        """Buffer ``content`` as a replacement edit for ``path``."""
        logger.info(f"Writing bytes to {path}")
        start = Position(line=0, character=0)
        if doc := self._get_doc(path):
            # NOTE(review): the character count is used for both the line and
            # the character of the end position, which overshoots the real end
            # of the document. Presumably this relies on out-of-range positions
            # being clamped by the receiver - confirm.
            end = Position(line=len(doc.source), character=len(doc.source))
        else:
            end = Position(line=0, character=0)
        self.changes[path.as_uri()] = TextEdit(range=Range(start=start, end=end), new_text=content.decode("utf-8"))

    def save_files(self, files: set[Path] | None = None) -> None:
        """Delegate to the wrapped ``FileIO``."""
        self.base_io.save_files(files)

    def check_changes(self) -> None:
        """Delegate to the wrapped ``FileIO``."""
        self.base_io.check_changes()

    def delete_file(self, path: Path) -> None:
        """Delegate to the wrapped ``FileIO``."""
        self.base_io.delete_file(path)

    def file_exists(self, path: Path) -> bool:
        """Report whether ``path`` is readable via the workspace or the wrapped ``FileIO``."""
        if doc := self._get_doc(path):
            try:
                # Touch ``source`` only for its side effect: a
                # ``FileNotFoundError`` here is treated as "does not exist".
                doc.source
            except FileNotFoundError:
                return False
            return True
        return self.base_io.file_exists(path)

    def untrack_file(self, path: Path) -> None:
        """Delegate to the wrapped ``FileIO``."""
        self.base_io.untrack_file(path)

    def get_document_changes(self) -> list[types.TextDocumentEdit]:
        """Return every buffered edit as a ``TextDocumentEdit`` and clear the buffer."""
        document_edits = [
            types.TextDocumentEdit(
                text_document=types.OptionalVersionedTextDocumentIdentifier(uri=uri),
                edits=[edit],
            )
            for uri, edit in self.changes.items()
        ]
        self.changes = {}
        return document_edits

src/codegen/extensions/lsp/kind.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from lsprotocol.types import SymbolKind
2+
3+
from codegen.sdk.core.assignment import Assignment
4+
from codegen.sdk.core.class_definition import Class
5+
from codegen.sdk.core.file import File
6+
from codegen.sdk.core.function import Function
7+
from codegen.sdk.core.interface import Interface
8+
from codegen.sdk.core.interfaces.editable import Editable
9+
from codegen.sdk.core.statements.attribute import Attribute
10+
from codegen.sdk.typescript.namespace import TSNamespace
11+
12+
# Lookup table from SDK node classes to LSP symbol kinds. ``get_kind`` walks
# it in insertion order and returns the first ``isinstance`` match, so the
# order of the entries is significant.
kinds: dict[type, SymbolKind] = {
    File: SymbolKind.File,
    Class: SymbolKind.Class,
    Function: SymbolKind.Function,
    Assignment: SymbolKind.Variable,
    Interface: SymbolKind.Interface,
    TSNamespace: SymbolKind.Namespace,
    Attribute: SymbolKind.Variable,
}
21+
22+
23+
def get_kind(node: Editable) -> SymbolKind:
    """Map an SDK node to its LSP ``SymbolKind``.

    Raises:
        ValueError: If ``node`` matches no entry of the ``kinds`` table.
    """
    # Methods are ``Function`` instances too, so they are singled out before
    # the generic table lookup would classify them as plain functions.
    if isinstance(node, Function) and node.is_method:
        return SymbolKind.Method

    for node_type, symbol_kind in kinds.items():
        if isinstance(node, node_type):
            return symbol_kind

    msg = f"No kind found for {node}, {type(node)}"
    raise ValueError(msg)

src/codegen/extensions/lsp/lsp.py

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import logging
2+
3+
from lsprotocol import types
4+
5+
import codegen
6+
from codegen.extensions.lsp.definition import go_to_definition
7+
from codegen.extensions.lsp.document_symbol import get_document_symbol
8+
from codegen.extensions.lsp.protocol import CodegenLanguageServerProtocol
9+
from codegen.extensions.lsp.range import get_range
10+
from codegen.extensions.lsp.server import CodegenLanguageServer
11+
from codegen.extensions.lsp.utils import get_path
12+
from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite
13+
from codegen.sdk.core.file import SourceFile
14+
15+
# Module-level logger used by the feature handlers in this module.
logger = logging.getLogger(__name__)

# Fall back to a placeholder when ``codegen`` exposes no ``__version__``.
version = getattr(codegen, "__version__", "v0.1")
server = CodegenLanguageServer("codegen", version, protocol_cls=CodegenLanguageServerProtocol)
18+
19+
20+
@server.feature(types.TEXT_DOCUMENT_DID_OPEN)
def did_open(server: CodegenLanguageServer, params: types.DidOpenTextDocumentParams) -> None:
    """Add a newly opened document to the codebase when it is not tracked yet.

    Nothing is done when the codebase already holds a ``SourceFile`` for the
    path, or when the file suffix is not one of the codebase's extensions.
    """
    logger.info(f"Document opened: {params.text_document.uri}")
    # pygls adds the document to the workspace itself; only the codebase
    # graph has to be brought up to date here.
    path = get_path(params.text_document.uri)
    existing = server.codebase.get_file(str(path), optional=True)
    if isinstance(existing, SourceFile):
        return
    if path.suffix not in server.codebase.ctx.extensions:
        return
    added = DiffLite(change_type=ChangeType.Added, path=path)
    server.codebase.ctx.apply_diffs([added])
31+
32+
33+
@server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
def did_change(server: CodegenLanguageServer, params: types.DidChangeTextDocumentParams) -> None:
    """Re-sync the codebase with a document whose content changed."""
    logger.info(f"Document changed: {params.text_document.uri}")
    # pygls updates the workspace document itself; the codebase only needs to
    # be told that the file was modified.
    modified = DiffLite(
        change_type=ChangeType.Modified,
        path=get_path(params.text_document.uri),
    )
    server.codebase.ctx.apply_diffs([modified])
42+
43+
44+
@server.feature(types.WORKSPACE_TEXT_DOCUMENT_CONTENT)
def workspace_text_document_content(server: CodegenLanguageServer, params: types.TextDocumentContentParams) -> types.TextDocumentContentResult:
    """Answer a text-document-content request with the file's current text.

    Returns a result with empty text (and logs a warning) when the server's
    IO layer reports that the file does not exist.
    """
    logger.debug(f"Workspace text document content: {params.uri}")
    path = get_path(params.uri)
    if server.io.file_exists(path):
        return types.TextDocumentContentResult(text=server.io.read_text(path))

    logger.warning(f"File does not exist: {path}")
    return types.TextDocumentContentResult(text="")
58+
59+
60+
@server.feature(types.TEXT_DOCUMENT_DID_CLOSE)
def did_close(server: CodegenLanguageServer, params: types.DidCloseTextDocumentParams) -> None:
    """Log that a document was closed; no other work is done here."""
    # pygls removes the document from the workspace on its own, so this
    # handler is only a hook for future cleanup.
    logger.info(f"Document closed: {params.text_document.uri}")
66+
67+
68+
@server.feature(types.TEXT_DOCUMENT_RENAME)
def rename(server: CodegenLanguageServer, params: types.RenameParams) -> types.RenameResult:
    """Rename the symbol at the given position and return the resulting edits.

    Returns ``None`` (after logging a warning) when no symbol is found at the
    position; otherwise a ``WorkspaceEdit`` built from the edits the IO layer
    collected while the codebase committed the rename.
    """
    target = server.get_symbol(params.text_document.uri, params.position)
    if target is None:
        logger.warning(f"No symbol found at {params.text_document.uri}:{params.position}")
        return None

    logger.info(f"Renaming symbol {target.name} to {params.new_name}")
    target.rename(params.new_name)
    # Committing writes the changed files through ``server.io``; those writes
    # are then read back as document changes.
    server.codebase.commit()
    document_changes = server.io.get_document_changes()
    return types.WorkspaceEdit(document_changes=document_changes)
82+
83+
84+
@server.feature(types.TEXT_DOCUMENT_DOCUMENT_SYMBOL)
def document_symbol(server: CodegenLanguageServer, params: types.DocumentSymbolParams) -> types.DocumentSymbolResult:
    """Return one ``DocumentSymbol`` tree per symbol of the requested file."""
    source_file = server.get_file(params.text_document.uri)
    return [get_document_symbol(symbol) for symbol in source_file.symbols]
93+
94+
95+
@server.feature(types.TEXT_DOCUMENT_DEFINITION)
def definition(server: CodegenLanguageServer, params: types.DefinitionParams) -> types.Location | None:
    """Return the location of the definition for the node under the cursor.

    Returns:
        A ``Location`` spanning the resolved node, or ``None`` when
        ``go_to_definition`` cannot resolve anything.
    """
    node = server.get_node_under_cursor(params.text_document.uri, params.position)
    resolved = go_to_definition(node, params.text_document.uri, params.position)
    # ``go_to_definition`` returns ``None`` for unresolvable nodes; without
    # this guard ``resolved.file`` raises ``AttributeError`` on every miss.
    if resolved is None:
        return None
    return types.Location(
        uri=resolved.file.path.as_uri(),
        range=get_range(resolved),
    )
105+
106+
107+
if __name__ == "__main__":
    # Script entry point: configure INFO-level logging, then start the
    # server's IO loop.
    logging.basicConfig(level=logging.INFO)
    server.start_io()
+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os
2+
import threading
3+
from pathlib import Path
4+
from typing import TYPE_CHECKING
5+
6+
from lsprotocol.types import INITIALIZE, INITIALIZED, InitializedParams, InitializeParams, InitializeResult
7+
from pygls.protocol import LanguageServerProtocol, lsp_method
8+
9+
from codegen.extensions.lsp.io import LSPIO
10+
from codegen.extensions.lsp.utils import get_path
11+
from codegen.sdk.codebase.config import CodebaseConfig, GSFeatureFlags
12+
from codegen.sdk.core.codebase import Codebase
13+
14+
if TYPE_CHECKING:
15+
from codegen.extensions.lsp.server import CodegenLanguageServer
16+
17+
18+
class CodegenLanguageServerProtocol(LanguageServerProtocol):
    """Protocol that builds the server's ``Codebase`` during the LSP handshake.

    The codebase is constructed on a worker thread started from ``initialize``
    and joined in ``initialized``.
    """

    _server: "CodegenLanguageServer"

    def _init_codebase(self, params: InitializeParams) -> None:
        """Create the ``Codebase`` and its ``LSPIO`` and attach both to the server.

        The repository root is ``params.root_path`` when set, else the path
        derived from ``params.root_uri``, else the current working directory.
        """
        if params.root_path:
            root = Path(params.root_path)
        elif params.root_uri:
            root = get_path(params.root_uri)
        else:
            # Same directory as ``os.getcwd()``, but a ``Path`` like the other branches.
            root = Path.cwd()
        config = CodebaseConfig(feature_flags=GSFeatureFlags(full_range_index=True))
        io = LSPIO(self.workspace)
        self._server.codebase = Codebase(repo_path=str(root), config=config, io=io)
        self._server.io = io

    @lsp_method(INITIALIZE)
    def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
        """Run the base ``initialize`` handling, then start building the codebase."""
        result = super().lsp_initialize(params)

        # The worker reads ``self.workspace``, so it is started only after the
        # base class has handled ``initialize`` (presumably where the
        # workspace is created - confirm).
        self._worker = threading.Thread(target=self._init_codebase, args=(params,))
        self._worker.start()
        return result

    @lsp_method(INITIALIZED)
    def lsp_initialized(self, params: InitializedParams) -> None:
        """Wait for the codebase worker to finish, then run the base handling."""
        self._worker.join()
        super().lsp_initialized(params)

src/codegen/extensions/lsp/range.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import tree_sitter
2+
from lsprotocol.types import Position, Range
3+
from pygls.workspace import TextDocument
4+
5+
from codegen.sdk.core.interfaces.editable import Editable
6+
7+
8+
def get_range(node: Editable) -> Range:
    """Return the LSP range spanning ``node`` and all of its extended nodes.

    The start is the start point with the smallest row and the end is the end
    point with the largest row; on equal rows the earliest candidate wins,
    with ``node`` itself considered first.
    """
    extended = list(node.extended_nodes)
    start_candidates = [node.start_point, *(ext.start_point for ext in extended)]
    end_candidates = [node.end_point, *(ext.end_point for ext in extended)]

    # NOTE(review): only rows are compared, so an extended node that shares a
    # row with ``node`` never widens the range by column - confirm intended.
    first = min(start_candidates, key=lambda point: point.row)
    last = max(end_candidates, key=lambda point: point.row)

    return Range(
        start=Position(line=first.row, character=first.column),
        end=Position(line=last.row, character=last.column),
    )
20+
21+
22+
def get_tree_sitter_range(range: Range, document: TextDocument) -> tree_sitter.Range:
    """Convert an LSP ``Range`` within ``document`` into a ``tree_sitter.Range``.

    Points are copied from the LSP line/character pairs; offsets come from
    ``document.offset_at_position``.
    """
    start, end = range.start, range.end
    # NOTE(review): the offsets are passed on as ``start_byte``/``end_byte``.
    # For non-ASCII text, character and byte offsets can differ - confirm
    # which unit ``offset_at_position`` returns.
    return tree_sitter.Range(
        start_point=tree_sitter.Point(row=start.line, column=start.character),
        end_point=tree_sitter.Point(row=end.line, column=end.character),
        start_byte=document.offset_at_position(start),
        end_byte=document.offset_at_position(end),
    )

0 commit comments

Comments
 (0)