
Commit 78ad481

feat: Update gpt4all, support multiple implementations in runtime (#472)
Signed-off-by: mudler <[email protected]>
1 parent 42d7538 commit 78ad481

12 files changed, +142 -29 lines

.gitignore (+3)

@@ -25,3 +25,6 @@ release/
 # just in case
 .DS_Store
 .idea
+
+# Generated during build
+backend-assets/

Makefile (+16 -18)

@@ -5,7 +5,7 @@ BINARY_NAME=local-ai
 
 GOLLAMA_VERSION?=10caf37d8b73386708b4373975b8917e6b212c0e
 GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
-GPT4ALL_VERSION?=337c7fecacfa4ae6779046513ab090687a5b0ef6
+GPT4ALL_VERSION?=022f1cabe7dd2c911936b37510582f279069ba1e
 GOGGMLTRANSFORMERS_VERSION?=13ccc22621bb21afecd38675a2b043498e2e756c
 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
 RWKV_VERSION?=ccb05c3e1c6efd098017d114dcb58ab3262b40b2
@@ -63,22 +63,13 @@ gpt4all:
 	git clone --recurse-submodules $(GPT4ALL_REPO) gpt4all
 	cd gpt4all && git checkout -b build $(GPT4ALL_VERSION) && git submodule update --init --recursive --depth 1
 	# This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
-	@find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
-	@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
-	@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} +
-	@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
-	@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} +
-	@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/set_console_color/set_gptj_console_color/g' {} +
-	@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/set_console_color/set_gptj_console_color/g' {} +
-	@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
-	@find ./gpt4all -type f -name "*.go" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
-	@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
-	@find ./gpt4all -type f -name "*.txt" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
-	@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} +
-	@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
-	@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +
-	@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/regex_escape/gpt4allregex_escape/g' {} +
-	mv ./gpt4all/gpt4all-backend/llama.cpp/llama_util.h ./gpt4all/gpt4all-backend/llama.cpp/gptjllama_util.h
+	@find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
+	@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
+	@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
+	@find ./gpt4all/gpt4all-bindings/golang -type f -name "*.cpp" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
+	@find ./gpt4all/gpt4all-bindings/golang -type f -name "*.go" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
+	@find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
+
 
 ## BERT embeddings
 go-bert:
@@ -124,6 +115,12 @@ bloomz/libbloomz.a: bloomz
 go-bert/libgobert.a: go-bert
 	$(MAKE) -C go-bert libgobert.a
 
+backend-assets/gpt4all: gpt4all/gpt4all-bindings/golang/libgpt4all.a
+	mkdir -p backend-assets/gpt4all
+	@cp gpt4all/gpt4all-bindings/golang/buildllm/*.so backend-assets/gpt4all/ || true
+	@cp gpt4all/gpt4all-bindings/golang/buildllm/*.dylib backend-assets/gpt4all/ || true
+	@cp gpt4all/gpt4all-bindings/golang/buildllm/*.dll backend-assets/gpt4all/ || true
+
 gpt4all/gpt4all-bindings/golang/libgpt4all.a: gpt4all
 	$(MAKE) -C gpt4all/gpt4all-bindings/golang/ libgpt4all.a
 
@@ -188,14 +185,15 @@ rebuild: ## Rebuilds the project
 	$(MAKE) -C bloomz clean
 	$(MAKE) build
 
-prepare: prepare-sources gpt4all/gpt4all-bindings/golang/libgpt4all.a $(OPTIONAL_TARGETS) go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building
+prepare: prepare-sources backend-assets/gpt4all $(OPTIONAL_TARGETS) go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building
 
 clean: ## Remove build related file
 	rm -fr ./go-llama
 	rm -rf ./gpt4all
 	rm -rf ./go-gpt2
 	rm -rf ./go-stable-diffusion
 	rm -rf ./go-ggml-transformers
+	rm -rf ./backend-assets
 	rm -rf ./go-rwkv
 	rm -rf ./go-bert
 	rm -rf ./bloomz

api/api.go (+7)

@@ -66,6 +66,13 @@ func App(opts ...AppOption) (*fiber.App, error) {
 			log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
 		}
 	}
+
+	if options.assetsDestination != "" {
+		if err := PrepareBackendAssets(options.backendAssets, options.assetsDestination); err != nil {
+			log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
+		}
+	}
+
 	// Default middleware config
 	app.Use(recover.New())

api/api_test.go (+1 -1)

@@ -257,7 +257,7 @@ var _ = Describe("API test", func() {
 		It("returns errors", func() {
 			_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
 			Expect(err).To(HaveOccurred())
-			Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:"))
+			Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 10 errors occurred:"))
 		})
 		It("transcribes audio", func() {
 			if runtime.GOOS != "linux" {

api/backend_assets.go (+27, new file)

@@ -0,0 +1,27 @@
+package api
+
+import (
+	"embed"
+	"os"
+	"path/filepath"
+
+	"github.com/go-skynet/LocalAI/pkg/assets"
+	"github.com/rs/zerolog/log"
+)
+
+func PrepareBackendAssets(backendAssets embed.FS, dst string) error {
+
+	// Extract files from the embedded FS
+	err := assets.ExtractFiles(backendAssets, dst)
+	if err != nil {
+		return err
+	}
+
+	// Set GPT4ALL libs where we extracted the files
+	// https://github.com/nomic-ai/gpt4all/commit/27e80e1d10985490c9fd4214e4bf458cfcf70896
+	gpt4alldir := filepath.Join(dst, "backend-assets", "gpt4all")
+	os.Setenv("GPT4ALL_IMPLEMENTATIONS_PATH", gpt4alldir)
+	log.Debug().Msgf("GPT4ALL_IMPLEMENTATIONS_PATH: %s", gpt4alldir)
+
+	return nil
+}
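
A minimal usage sketch of this helper, assuming the binary embeds a backend-assets/ tree as in assets.go below and that the api package resolves to github.com/go-skynet/LocalAI/api; the destination path reuses the default of the new --backend-assets-path flag:

package main

import (
	"embed"
	"fmt"
	"os"

	"github.com/go-skynet/LocalAI/api"
)

// Expected to be populated at build time by the new backend-assets/gpt4all
// Makefile target.
//go:embed backend-assets/*
var backendAssets embed.FS

func main() {
	dst := "/tmp/localai/backend_data"

	// Extracts the embedded libraries under dst and points gpt4all at
	// dst/backend-assets/gpt4all via GPT4ALL_IMPLEMENTATIONS_PATH.
	if err := api.PrepareBackendAssets(backendAssets, dst); err != nil {
		fmt.Fprintf(os.Stderr, "failed extracting backend assets: %v\n", err)
		os.Exit(1)
	}

	fmt.Println(os.Getenv("GPT4ALL_IMPLEMENTATIONS_PATH"))
}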

api/options.go (+16)

@@ -2,6 +2,7 @@ package api
 
 import (
 	"context"
+	"embed"
 
 	model "github.com/go-skynet/LocalAI/pkg/model"
 )
@@ -18,6 +19,9 @@ type Option struct {
 	preloadJSONModels string
 	preloadModelsFromPath string
 	corsAllowOrigins string
+
+	backendAssets embed.FS
+	assetsDestination string
 }
 
 type AppOption func(*Option)
@@ -49,6 +53,18 @@ func WithCorsAllowOrigins(b string) AppOption {
 	}
 }
 
+func WithBackendAssetsOutput(out string) AppOption {
+	return func(o *Option) {
+		o.assetsDestination = out
+	}
+}
+
+func WithBackendAssets(f embed.FS) AppOption {
+	return func(o *Option) {
+		o.backendAssets = f
+	}
+}
+
 func WithContext(ctx context.Context) AppOption {
 	return func(o *Option) {
 		o.context = ctx

assets.go (+6, new file)

@@ -0,0 +1,6 @@
+package main
+
+import "embed"
+
+//go:embed backend-assets/*
+var backendAssets embed.FS

go.mod (+1 -1)

@@ -15,7 +15,7 @@ require (
 	github.com/hashicorp/go-multierror v1.1.1
 	github.com/imdario/mergo v0.3.16
 	github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642
-	github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5
+	github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230601151908-5175db27813c
 	github.com/onsi/ginkgo/v2 v2.9.7
 	github.com/onsi/gomega v1.27.7
 	github.com/otiai10/openaigo v1.1.0

go.sum (+2)

@@ -155,6 +155,8 @@ github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230528235700-9eb81c
 github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230528235700-9eb81cb54922/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
 github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5 h1:99cF+V5wk7IInDAEM9HAlSHdLf/xoJR529Wr8lAG5KQ=
 github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
+github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230601151908-5175db27813c h1:KXYqUH6bdYbxnF67l8wayctaCZ4BQJQOsUyNke7HC0A=
+github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230601151908-5175db27813c/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
 github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
 github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
 github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=

main.go (+8)

@@ -80,6 +80,12 @@ func main() {
 			EnvVars: []string{"IMAGE_PATH"},
 			Value:   "",
 		},
+		&cli.StringFlag{
+			Name:        "backend-assets-path",
+			DefaultText: "Path used to extract libraries that are required by some of the backends in runtime.",
+			EnvVars:     []string{"BACKEND_ASSETS_PATH"},
+			Value:       "/tmp/localai/backend_data",
+		},
 		&cli.IntFlag{
 			Name:        "context-size",
 			DefaultText: "Default context size of the model",
@@ -124,6 +130,8 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
 			api.WithCors(ctx.Bool("cors")),
 			api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
 			api.WithThreads(ctx.Int("threads")),
+			api.WithBackendAssets(backendAssets),
+			api.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
 			api.WithUploadLimitMB(ctx.Int("upload-limit")))
 		if err != nil {
 			return err

pkg/assets/extract.go (+51, new file)

@@ -0,0 +1,51 @@
+package assets
+
+import (
+	"embed"
+	"fmt"
+	"io/fs"
+	"os"
+	"path/filepath"
+)
+
+func ExtractFiles(content embed.FS, extractDir string) error {
+	// Create the target directory if it doesn't exist
+	err := os.MkdirAll(extractDir, 0755)
+	if err != nil {
+		return fmt.Errorf("failed to create directory: %v", err)
+	}
+
+	// Walk through the embedded FS and extract files
+	err = fs.WalkDir(content, ".", func(path string, d fs.DirEntry, err error) error {
+		if err != nil {
+			return err
+		}
+
+		// Reconstruct the directory structure in the target directory
+		targetFile := filepath.Join(extractDir, path)
+		if d.IsDir() {
+			// Create the directory in the target directory
+			err := os.MkdirAll(targetFile, 0755)
+			if err != nil {
+				return fmt.Errorf("failed to create directory: %v", err)
+			}
+			return nil
+		}
+
+		// Read the file from the embedded FS
+		fileData, err := content.ReadFile(path)
+		if err != nil {
+			return fmt.Errorf("failed to read file: %v", err)
+		}
+
+		// Create the file in the target directory
+		err = os.WriteFile(targetFile, fileData, 0644)
+		if err != nil {
+			return fmt.Errorf("failed to write file: %v", err)
+		}
+
+		return nil
+	})
+
+	return err
+}
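
A small sketch of how ExtractFiles can be driven on its own: the embedded tree is recreated verbatim under extractDir (directories 0755, files 0644), so an embedded backend-assets/gpt4all/<lib>.so ends up at <dst>/backend-assets/gpt4all/<lib>.so. The data/ directory below is hypothetical, used only for illustration:

package main

import (
	"embed"
	"log"

	"github.com/go-skynet/LocalAI/pkg/assets"
)

// Hypothetical embedded tree for this example; the real binary embeds
// backend-assets/* instead (see assets.go above).
//go:embed data/*
var content embed.FS

func main() {
	// Every entry under data/ is written below /tmp/example, preserving
	// the relative path, e.g. data/foo.txt -> /tmp/example/data/foo.txt.
	if err := assets.ExtractFiles(content, "/tmp/example"); err != nil {
		log.Fatalf("extract failed: %v", err)
	}
}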

pkg/model/initializers.go (+4 -9)

@@ -33,6 +33,7 @@ const (
 	Gpt4AllLlamaBackend = "gpt4all-llama"
 	Gpt4AllMptBackend = "gpt4all-mpt"
 	Gpt4AllJBackend = "gpt4all-j"
+	Gpt4All = "gpt4all"
 	BertEmbeddingsBackend = "bert-embeddings"
 	RwkvBackend = "rwkv"
 	WhisperBackend = "whisper"
@@ -42,9 +43,7 @@ const (
 
 var backends []string = []string{
 	LlamaBackend,
-	Gpt4AllLlamaBackend,
-	Gpt4AllMptBackend,
-	Gpt4AllJBackend,
+	Gpt4All,
 	RwkvBackend,
 	GPTNeoXBackend,
 	WhisperBackend,
@@ -153,12 +152,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
 		return ml.LoadModel(modelFile, stableDiffusion)
 	case StarcoderBackend:
 		return ml.LoadModel(modelFile, starCoder)
-	case Gpt4AllLlamaBackend:
-		return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.LLaMAType)))
-	case Gpt4AllMptBackend:
-		return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.MPTType)))
-	case Gpt4AllJBackend:
-		return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.GPTJType)))
+	case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:
+		return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads))))
 	case BertEmbeddingsBackend:
 		return ml.LoadModel(modelFile, bertEmbeddings)
 	case RwkvBackend:
