diff --git a/pkg/plugin_packager/decoder/zip.go b/pkg/plugin_packager/decoder/zip.go index d991ccdc3..5d180d05f 100644 --- a/pkg/plugin_packager/decoder/zip.go +++ b/pkg/plugin_packager/decoder/zip.go @@ -7,9 +7,9 @@ import ( "fmt" "io" "io/fs" + "math" "os" "path" - "path/filepath" "strconv" "strings" @@ -18,6 +18,8 @@ import ( "github.com/langgenius/dify-plugin-daemon/pkg/utils/parser" ) +var errUnsafeZipPath = errors.New("unsafe path in plugin package") + type ZipPluginDecoder struct { PluginDecoder PluginDecoderHelper @@ -297,27 +299,46 @@ func (z *ZipPluginDecoder) UniqueIdentity() (plugin_entities.PluginUniqueIdentif func (z *ZipPluginDecoder) ExtractTo(dst string) error { // copy to working directory - if err := z.Walk(func(filename, dir string) error { - workingPath := path.Join(dst, dir) - // check if directory exists - if err := os.MkdirAll(workingPath, 0755); err != nil { + if z.reader == nil { + return z.err + } + + if err := func() error { + if err := os.MkdirAll(dst, 0755); err != nil { return err } - bytes, err := z.ReadFile(filepath.Join(dir, filename)) + root, err := os.OpenRoot(dst) if err != nil { return err } - - filename = filepath.Join(workingPath, filename) - - // copy file - if err := os.WriteFile(filename, bytes, 0644); err != nil { - return err + defer root.Close() + + for _, file := range z.reader.File { + entryPath, err := safeEntryPath(file.Name) + if err != nil { + return err + } + + if file.FileInfo().IsDir() { + if err := root.MkdirAll(entryPath, 0755); err != nil { + return err + } + continue + } + + // check if directory exists + if err := root.MkdirAll(path.Dir(entryPath), 0755); err != nil { + return err + } + + if err := extractZipFile(root, entryPath, file); err != nil { + return err + } } return nil - }); err != nil { + }(); err != nil { // if error, delete the working directory os.RemoveAll(dst) return errors.Join(fmt.Errorf("copy plugin to working directory error: %v", err), err) @@ -326,6 +347,57 @@ func (z *ZipPluginDecoder) ExtractTo(dst string) error { return nil } +func safeEntryPath(entryName string) (string, error) { + if entryName == "" || strings.Contains(entryName, `\`) { + return "", fmt.Errorf("%w: %q", errUnsafeZipPath, entryName) + } + + for _, part := range strings.Split(entryName, "/") { + if part == ".." { + return "", fmt.Errorf("%w: %q", errUnsafeZipPath, entryName) + } + } + + entryPath := path.Clean(entryName) + if entryPath == "." || path.IsAbs(entryPath) { + return "", fmt.Errorf("%w: %q", errUnsafeZipPath, entryName) + } + + return entryPath, nil +} + +func extractZipFile(root *os.Root, entryPath string, file *zip.File) error { + reader, err := file.Open() + if err != nil { + return err + } + defer reader.Close() + + writer, err := root.OpenFile(entryPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer writer.Close() + + return copyZipFile(writer, reader, file.UncompressedSize64) +} + +func copyZipFile(writer io.Writer, reader io.Reader, uncompressedSize uint64) error { + if uncompressedSize > math.MaxInt64-1 { + return fmt.Errorf("zip entry is too large: %d bytes", uncompressedSize) + } + + limit := int64(uncompressedSize) + 1 + written, err := io.Copy(writer, io.LimitReader(reader, limit)) + if err != nil { + return err + } + if written > int64(uncompressedSize) { + return fmt.Errorf("zip entry exceeds declared uncompressed size: %d bytes", uncompressedSize) + } + return nil +} + func (z *ZipPluginDecoder) CheckAssetsValid() error { return z.PluginDecoderHelper.CheckAssetsValid(z) } diff --git a/pkg/plugin_packager/decoder/zip_extract_test.go b/pkg/plugin_packager/decoder/zip_extract_test.go new file mode 100644 index 000000000..770e4ca0b --- /dev/null +++ b/pkg/plugin_packager/decoder/zip_extract_test.go @@ -0,0 +1,127 @@ +package decoder + +import ( + "archive/zip" + "bytes" + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func buildZipPlugin(t *testing.T, files map[string][]byte) []byte { + t.Helper() + + var buffer bytes.Buffer + zipWriter := zip.NewWriter(&buffer) + for name, content := range files { + if len(content) == 0 && strings.HasSuffix(name, "/") { + _, err := zipWriter.Create(name) + require.NoError(t, err) + continue + } + + writer, err := zipWriter.Create(name) + require.NoError(t, err) + _, err = writer.Write(content) + require.NoError(t, err) + } + require.NoError(t, zipWriter.Close()) + + return buffer.Bytes() +} + +func minimalPluginFiles(t *testing.T) map[string][]byte { + t.Helper() + + manifest, err := os.ReadFile(filepath.Join("..", "testdata", "manifest.yaml")) + require.NoError(t, err) + endpoint, err := os.ReadFile(filepath.Join("..", "testdata", "neko.yaml")) + require.NoError(t, err) + + return map[string][]byte{ + "manifest.yaml": manifest, + "neko.yaml": endpoint, + } +} + +func TestZipPluginDecoderExtractToRejectsParentPath(t *testing.T) { + files := minimalPluginFiles(t) + files["../escaped.txt"] = []byte("escaped") + + zipDecoder, err := NewZipPluginDecoder(buildZipPlugin(t, files)) + require.NoError(t, err) + + parent := t.TempDir() + dst := filepath.Join(parent, "plugin") + err = zipDecoder.ExtractTo(dst) + + require.Error(t, err) + assert.True(t, errors.Is(err, errUnsafeZipPath)) + assert.NoFileExists(t, filepath.Join(parent, "escaped.txt")) + assert.NoDirExists(t, dst) +} + +func TestZipPluginDecoderExtractToRejectsBackslashPath(t *testing.T) { + files := minimalPluginFiles(t) + files[`..\escaped.txt`] = []byte("escaped") + + zipDecoder, err := NewZipPluginDecoder(buildZipPlugin(t, files)) + require.NoError(t, err) + + parent := t.TempDir() + dst := filepath.Join(parent, "plugin") + err = zipDecoder.ExtractTo(dst) + + require.Error(t, err) + assert.True(t, errors.Is(err, errUnsafeZipPath)) + assert.NoFileExists(t, filepath.Join(parent, "escaped.txt")) + assert.NoDirExists(t, dst) +} + +func TestZipPluginDecoderExtractToAllowsNestedPath(t *testing.T) { + files := minimalPluginFiles(t) + files["nested/"] = nil + files["nested/file.txt"] = []byte("ok") + + zipDecoder, err := NewZipPluginDecoder(buildZipPlugin(t, files)) + require.NoError(t, err) + + dst := filepath.Join(t.TempDir(), "plugin") + require.NoError(t, zipDecoder.ExtractTo(dst)) + + extracted, err := os.ReadFile(filepath.Join(dst, "nested", "file.txt")) + require.NoError(t, err) + assert.Equal(t, []byte("ok"), extracted) +} + +func TestSafeEntryPathRejectsParentDirectoryEntry(t *testing.T) { + _, err := safeEntryPath("..") + + require.Error(t, err) + assert.True(t, errors.Is(err, errUnsafeZipPath)) +} + +func TestCopyZipFileRejectsContentBeyondDeclaredSize(t *testing.T) { + var out bytes.Buffer + + err := copyZipFile(&out, strings.NewReader("toolarge"), 3) + + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds declared uncompressed size") + assert.Equal(t, "tool", out.String()) +} + +func TestCopyZipFileAllowsDeclaredSize(t *testing.T) { + var out bytes.Buffer + + err := copyZipFile(&out, io.NopCloser(strings.NewReader("ok")), 2) + + require.NoError(t, err) + assert.Equal(t, "ok", out.String()) +}