Skip to content

Commit 315d417

Browse files
Add support for relative symlinks; better testing for interior links
1 parent ce84e65 commit 315d417

File tree

2 files changed

+87
-31
lines changed

2 files changed

+87
-31
lines changed

memfs/memfs.go

+30-25
Original file line numberDiff line numberDiff line change
@@ -179,26 +179,17 @@ func (fs *MemFS) ReadDir(path string) ([]os.FileInfo, error) {
179179
}
180180

181181
func (fs *MemFS) fileInfo(path string) (parent *fileInfo, node *fileInfo, err error) {
182-
path = filepath.Clean(path)
183-
segments := vfs.SplitPath(path, PathSeparator)
182+
return fs.relativeFileInfo(fs.wd, path)
183+
}
184184

185-
// Shortcut for working directory and root
186-
if len(segments) == 1 {
187-
if segments[0] == "" {
188-
return nil, fs.root, nil
189-
} else if segments[0] == "." {
190-
return fs.wd.parent, fs.wd, nil
191-
}
192-
}
185+
func (fs *MemFS) relativeFileInfo(wd *fileInfo, path string) (parent *fileInfo, node *fileInfo, err error) {
186+
parent, segments := fs.dirSegments(wd, path)
193187

194-
// Determine root to traverse
195-
parent = fs.root
196-
if segments[0] == "." {
197-
parent = fs.wd
188+
// Shortcut for working directory and root
189+
if len(segments) == 0 {
190+
return parent.parent, parent, nil
198191
}
199-
segments = segments[1:]
200192

201-
// Further directories
202193
for _, seg := range segments[:len(segments)-1] {
203194

204195
if parent.childs == nil {
@@ -211,10 +202,16 @@ func (fs *MemFS) fileInfo(path string) (parent *fileInfo, node *fileInfo, err er
211202
if entry.dir {
212203
parent = entry
213204
} else if entry.mode & os.ModeSymlink != 0 {
214-
_, parent, err = fs.fileInfo(string(*entry.buf))
205+
// Look up interior symlink
206+
_, parent, err = fs.relativeFileInfo(parent, string(*entry.buf))
215207
if err != nil {
216208
return nil, nil, err
217209
}
210+
// Symlink was not to a directory
211+
if parent == nil {
212+
return nil, nil, vfs.ErrNotDirectory
213+
}
214+
218215
} else {
219216
return nil, nil, os.ErrNotExist
220217
}
@@ -223,6 +220,9 @@ func (fs *MemFS) fileInfo(path string) (parent *fileInfo, node *fileInfo, err er
223220
lastSeg := segments[len(segments)-1]
224221
if parent.childs != nil {
225222
if node, ok := parent.childs[lastSeg]; ok {
223+
if node.mode & os.ModeSymlink != 0 {
224+
return fs.relativeFileInfo(parent, string(*node.buf))
225+
}
226226
return parent, node, nil
227227
}
228228
} else {
@@ -232,6 +232,19 @@ func (fs *MemFS) fileInfo(path string) (parent *fileInfo, node *fileInfo, err er
232232
return parent, nil, nil
233233
}
234234

235+
func (fs *MemFS) dirSegments(wd *fileInfo, path string) (parent *fileInfo, segments []string) {
236+
path = filepath.Clean(path)
237+
segments = vfs.SplitPath(path, PathSeparator)
238+
239+
// Determine root to traverse
240+
parent = fs.root
241+
if segments[0] == "." {
242+
parent = wd
243+
}
244+
segments = segments[1:]
245+
return parent, segments
246+
}
247+
235248
func hasFlag(flag int, flags int) bool {
236249
return flags&flag == flag
237250
}
@@ -243,11 +256,6 @@ func (fs *MemFS) OpenFile(name string, flag int, perm os.FileMode) (vfs.File, er
243256
fs.lock.Lock()
244257
defer fs.lock.Unlock()
245258

246-
return fs.openFile(name, flag, perm)
247-
}
248-
249-
func (fs *MemFS) openFile(name string, flag int, perm os.FileMode) (vfs.File, error) {
250-
251259
name = filepath.Clean(name)
252260
base := filepath.Base(name)
253261
fiParent, fiNode, err := fs.fileInfo(name)
@@ -275,9 +283,6 @@ func (fs *MemFS) openFile(name string, flag int, perm os.FileMode) (vfs.File, er
275283
if fiNode.dir {
276284
return nil, &os.PathError{"open", name, ErrIsDirectory}
277285
}
278-
if fiNode.mode & os.ModeSymlink != 0 {
279-
return fs.openFile(string(*fiNode.buf), flag, perm)
280-
}
281286
}
282287

283288
if !hasFlag(os.O_RDONLY, flag) {

memfs/memfs_test.go

+57-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package memfs
33
import (
44
"io/ioutil"
55
"os"
6+
"strings"
67
"testing"
78
"time"
89

@@ -139,10 +140,6 @@ func TestSymlink(t *testing.T) {
139140
if err != nil {
140141
t.Fatal("Symlink failed:", err)
141142
}
142-
_, node, err := fs.fileInfo("/tmp/cup")
143-
if string(*node.buf) != "/tmp/teacup" {
144-
t.Fatal("Wrong symlink contents in buf:", string(*node.buf))
145-
}
146143
fluid, err := vfs.ReadFile(fs, "/tmp/cup")
147144
if err != nil {
148145
t.Fatal("Failed to read from /tmp/cup:", err)
@@ -158,9 +155,11 @@ func TestDirectorySymlink(t *testing.T) {
158155
if err := vfs.MkdirAll(fs, "/foo/a/b", 0755); err != nil {
159156
t.Fatal("Unable mkdir /foo/a/b:", err)
160157
}
158+
161159
if err := vfs.WriteFile(fs, "/foo/a/b/c", []byte("I can \"c\" clearly now"), 0644); err != nil {
162160
t.Fatal("Unable to write /foo/a/b/c:", err)
163161
}
162+
164163
if err := fs.Symlink("/foo/a/b", "/foo/also_b"); err != nil {
165164
t.Fatal("Unable to symlink /foo/also_b -> /foo/a/b:", err)
166165
}
@@ -174,8 +173,60 @@ func TestDirectorySymlink(t *testing.T) {
174173
}
175174
}
176175

177-
// TODO: relative symlinks
178-
// TODO: overwrite symlinks
176+
func TestMultipleAndRelativeSymlinks(t *testing.T) {
177+
fs := Create()
178+
if err := vfs.MkdirAll(fs, "a/real_b/real_c", 0755); err != nil {
179+
t.Fatal("Unable mkdir a/real_b/real_c:", err)
180+
}
181+
182+
for _, fsEntry := range []struct {
183+
name, link, content string
184+
}{
185+
{name: "a/b", link: "real_b"},
186+
{name: "a/b/c", link: "real_c"},
187+
{name: "a/b/c/real_d", content: "Lah dee dah"},
188+
{name: "a/b/c/d", link: "real_d"},
189+
{name: "a/d", link: "b/c/d"},
190+
} {
191+
if fsEntry.link != "" {
192+
if err := fs.Symlink(fsEntry.link, fsEntry.name); err != nil {
193+
t.Fatalf("Unable to symlink %s -> %s: %v", fsEntry.name, fsEntry.link, err)
194+
}
195+
} else if fsEntry.content != "" {
196+
if err := vfs.WriteFile(fs, fsEntry.name, []byte(fsEntry.content), 0644); err != nil {
197+
t.Fatalf("Unable to write %s: %v", fsEntry.name, err)
198+
}
199+
}
200+
}
201+
202+
for _, fn := range []string{
203+
"a/b/c/d",
204+
"a/d",
205+
} {
206+
contents, err := vfs.ReadFile(fs, fn)
207+
if err != nil {
208+
t.Fatalf("Unable to read %s: %v", fn, err)
209+
}
210+
if string(contents) != "Lah dee dah" {
211+
t.Fatalf("Unexpected contents read from %s: %v", fn, err)
212+
}
213+
}
214+
}
215+
216+
func TestSymlinkIsNotADirectory(t *testing.T) {
217+
fs := Create()
218+
if err := vfs.MkdirAll(fs, "a/real_b/real_c", 0755); err != nil {
219+
t.Fatal("Unable mkdir a/real_b/real_c:", err)
220+
}
221+
if err := fs.Symlink("broken", "a/b"); err != nil {
222+
t.Fatal("Unable to symlink a/b -> broken:", err)
223+
}
224+
if err := vfs.WriteFile(fs, "a/b/c", []byte("Whatever"), 0644); !strings.Contains(err.Error(), vfs.ErrNotDirectory.Error()) {
225+
t.Fatal("Expected an error when writing a/b/c:", err)
226+
}
227+
}
228+
229+
// TODO: overwrite/remove symlinks
179230

180231
func TestReadDir(t *testing.T) {
181232
fs := Create()

0 commit comments

Comments
 (0)