Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 85 additions & 13 deletions pkg/plugin_packager/decoder/zip.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"fmt"
"io"
"io/fs"
"math"
"os"
"path"
"path/filepath"
"strconv"
"strings"

Expand All @@ -18,6 +18,8 @@ import (
"github.com/langgenius/dify-plugin-daemon/pkg/utils/parser"
)

var errUnsafeZipPath = errors.New("unsafe path in plugin package")

type ZipPluginDecoder struct {
PluginDecoder
PluginDecoderHelper
Expand Down Expand Up @@ -297,27 +299,46 @@ func (z *ZipPluginDecoder) UniqueIdentity() (plugin_entities.PluginUniqueIdentif

func (z *ZipPluginDecoder) ExtractTo(dst string) error {
// copy to working directory
if err := z.Walk(func(filename, dir string) error {
workingPath := path.Join(dst, dir)
// check if directory exists
if err := os.MkdirAll(workingPath, 0755); err != nil {
if z.reader == nil {
return z.err
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

root, err := os.OpenRoot(dst)
  	if err != nil {
  		return err
  	}
  	defer root.Close()

  	for _, file := range z.reader.File {
  		entryPath, err := safeEntryPath(file.Name)
  		if err != nil {
  			return err
  		}

  		if file.FileInfo().IsDir() {
  			if err := root.MkdirAll(entryPath, 0755); err != nil {
  				return err
  			}
  			continue
  		}

  		if err := root.MkdirAll(path.Dir(entryPath), 0755); err != nil {
  			return err
  		}

  		if err := extractZipFile(root, entryPath, file); err != nil {
  			return err
  		}
  	}

  	return nil


if err := func() error {
if err := os.MkdirAll(dst, 0755); err != nil {
return err
}

bytes, err := z.ReadFile(filepath.Join(dir, filename))
root, err := os.OpenRoot(dst)
if err != nil {
return err
}

filename = filepath.Join(workingPath, filename)

// copy file
if err := os.WriteFile(filename, bytes, 0644); err != nil {
return err
defer root.Close()

for _, file := range z.reader.File {
entryPath, err := safeEntryPath(file.Name)
if err != nil {
return err
}

if file.FileInfo().IsDir() {
if err := root.MkdirAll(entryPath, 0755); err != nil {
return err
}
continue
}

// check if directory exists
if err := root.MkdirAll(path.Dir(entryPath), 0755); err != nil {
return err
}

if err := extractZipFile(root, entryPath, file); err != nil {
return err
}
}

return nil
}); err != nil {
}(); err != nil {
// if error, delete the working directory
os.RemoveAll(dst)
return errors.Join(fmt.Errorf("copy plugin to working directory error: %v", err), err)
Expand All @@ -326,6 +347,57 @@ func (z *ZipPluginDecoder) ExtractTo(dst string) error {
return nil
}

func safeEntryPath(entryName string) (string, error) {
if entryName == "" || strings.Contains(entryName, `\`) {
return "", fmt.Errorf("%w: %q", errUnsafeZipPath, entryName)
}

for _, part := range strings.Split(entryName, "/") {
if part == ".." {
return "", fmt.Errorf("%w: %q", errUnsafeZipPath, entryName)
}
}

entryPath := path.Clean(entryName)
if entryPath == "." || path.IsAbs(entryPath) {
return "", fmt.Errorf("%w: %q", errUnsafeZipPath, entryName)
}

return entryPath, nil
}

func extractZipFile(root *os.Root, entryPath string, file *zip.File) error {
reader, err := file.Open()
if err != nil {
return err
}
defer reader.Close()

writer, err := root.OpenFile(entryPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return err
}
defer writer.Close()

return copyZipFile(writer, reader, file.UncompressedSize64)
}

func copyZipFile(writer io.Writer, reader io.Reader, uncompressedSize uint64) error {
if uncompressedSize > math.MaxInt64-1 {
return fmt.Errorf("zip entry is too large: %d bytes", uncompressedSize)
}

limit := int64(uncompressedSize) + 1
written, err := io.Copy(writer, io.LimitReader(reader, limit))
if err != nil {
return err
}
if written > int64(uncompressedSize) {
return fmt.Errorf("zip entry exceeds declared uncompressed size: %d bytes", uncompressedSize)
}
return nil
}

func (z *ZipPluginDecoder) CheckAssetsValid() error {
return z.PluginDecoderHelper.CheckAssetsValid(z)
}
Expand Down
127 changes: 127 additions & 0 deletions pkg/plugin_packager/decoder/zip_extract_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package decoder

import (
"archive/zip"
"bytes"
"errors"
"io"
"os"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func buildZipPlugin(t *testing.T, files map[string][]byte) []byte {
t.Helper()

var buffer bytes.Buffer
zipWriter := zip.NewWriter(&buffer)
for name, content := range files {
if len(content) == 0 && strings.HasSuffix(name, "/") {
_, err := zipWriter.Create(name)
require.NoError(t, err)
continue
}

writer, err := zipWriter.Create(name)
require.NoError(t, err)
_, err = writer.Write(content)
require.NoError(t, err)
}
require.NoError(t, zipWriter.Close())

return buffer.Bytes()
}

func minimalPluginFiles(t *testing.T) map[string][]byte {
t.Helper()

manifest, err := os.ReadFile(filepath.Join("..", "testdata", "manifest.yaml"))
require.NoError(t, err)
endpoint, err := os.ReadFile(filepath.Join("..", "testdata", "neko.yaml"))
require.NoError(t, err)

return map[string][]byte{
"manifest.yaml": manifest,
"neko.yaml": endpoint,
}
}

func TestZipPluginDecoderExtractToRejectsParentPath(t *testing.T) {
files := minimalPluginFiles(t)
files["../escaped.txt"] = []byte("escaped")

zipDecoder, err := NewZipPluginDecoder(buildZipPlugin(t, files))
require.NoError(t, err)

parent := t.TempDir()
dst := filepath.Join(parent, "plugin")
err = zipDecoder.ExtractTo(dst)

require.Error(t, err)
assert.True(t, errors.Is(err, errUnsafeZipPath))
assert.NoFileExists(t, filepath.Join(parent, "escaped.txt"))
assert.NoDirExists(t, dst)
}

func TestZipPluginDecoderExtractToRejectsBackslashPath(t *testing.T) {
files := minimalPluginFiles(t)
files[`..\escaped.txt`] = []byte("escaped")

zipDecoder, err := NewZipPluginDecoder(buildZipPlugin(t, files))
require.NoError(t, err)

parent := t.TempDir()
dst := filepath.Join(parent, "plugin")
err = zipDecoder.ExtractTo(dst)

require.Error(t, err)
assert.True(t, errors.Is(err, errUnsafeZipPath))
assert.NoFileExists(t, filepath.Join(parent, "escaped.txt"))
assert.NoDirExists(t, dst)
}

func TestZipPluginDecoderExtractToAllowsNestedPath(t *testing.T) {
files := minimalPluginFiles(t)
files["nested/"] = nil
files["nested/file.txt"] = []byte("ok")

zipDecoder, err := NewZipPluginDecoder(buildZipPlugin(t, files))
require.NoError(t, err)

dst := filepath.Join(t.TempDir(), "plugin")
require.NoError(t, zipDecoder.ExtractTo(dst))

extracted, err := os.ReadFile(filepath.Join(dst, "nested", "file.txt"))
require.NoError(t, err)
assert.Equal(t, []byte("ok"), extracted)
}

func TestSafeEntryPathRejectsParentDirectoryEntry(t *testing.T) {
_, err := safeEntryPath("..")

require.Error(t, err)
assert.True(t, errors.Is(err, errUnsafeZipPath))
}

func TestCopyZipFileRejectsContentBeyondDeclaredSize(t *testing.T) {
var out bytes.Buffer

err := copyZipFile(&out, strings.NewReader("toolarge"), 3)

require.Error(t, err)
assert.Contains(t, err.Error(), "exceeds declared uncompressed size")
assert.Equal(t, "tool", out.String())
}

func TestCopyZipFileAllowsDeclaredSize(t *testing.T) {
var out bytes.Buffer

err := copyZipFile(&out, io.NopCloser(strings.NewReader("ok")), 2)

require.NoError(t, err)
assert.Equal(t, "ok", out.String())
}