-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstate.go
More file actions
134 lines (109 loc) · 3.04 KB
/
state.go
File metadata and controls
134 lines (109 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package main
import (
"errors"
"fmt"
lua "github.com/yuin/gopher-lua"
"sync"
)
// LuaResource is implemented by values that NewLuaRunner can preload into a
// runner's Lua state (LuaLib and LuaFunc in this file).
type LuaResource interface {
	IsLuaResource()
}

// LuaRunner executes Lua code against a gopher-lua state. Run serialises
// access to that state with lsLock.
type LuaRunner struct {
	ls           *lua.LState // state that Run loads and calls code in
	cleanLs      *lua.LState // state refreshLState swaps in; NOTE(review): NewLuaRunner sets it to the same object as ls
	lsLock       sync.Mutex  // held by Run for the whole call
	skipOpenLibs bool        // NOTE(review): written by NewLuaRunner but not read anywhere visible here
	isLsClean    bool        // when false, Run calls refreshLState before executing
}

// LuaLib names a base Lua library to preload; NewLuaRunner looks it up in
// baseLuaLibs.
type LuaLib string

// LuaFunc names a Go function to expose as a Lua global; NewLuaRunner looks it
// up in baseLuaFns.
type LuaFunc string

// LuaFunctionName is the name of the global Lua function that Run calls.
type LuaFunctionName string

// LuaFunctionCode is the Lua source that Run loads before making the call.
type LuaFunctionCode string

// LuaFunctionRetNum is the number of values Run collects from the call.
type LuaFunctionRetNum uint
// IsLuaResource marks LuaLib as a LuaResource.
func (LuaLib) IsLuaResource() {}

// IsLuaResource marks LuaFunc as a LuaResource.
func (LuaFunc) IsLuaResource() {}

// Compile-time checks that both resource kinds satisfy LuaResource.
var (
	_ LuaResource = LuaLib("")
	_ LuaResource = LuaFunc("")
)
// NewLuaRunner creates a LuaRunner whose Lua state is preloaded with the given
// resources.
//
// skipOpenLibs is forwarded to lua.Options.SkipOpenLibs when the state is
// created and is recorded on the runner. Each LuaLib must be a key of
// baseLuaLibs and is opened via loadLuaLib. Each LuaFunc must be a key of
// baseLuaFns and is registered as a Lua global with the same name. Any other
// resource type (including a nil resource) is rejected.
//
// On any error the partially initialised state is closed and (nil, err) is
// returned.
func NewLuaRunner(skipOpenLibs bool, resources ...LuaResource) (*LuaRunner, error) {
	lr := &LuaRunner{
		ls:           lua.NewState(lua.Options{SkipOpenLibs: skipOpenLibs}),
		skipOpenLibs: skipOpenLibs, // was hard-coded to false, ignoring the argument
		isLsClean:    true,         // start with a clean state
	}

	// fail closes the freshly created state so it is not leaked on error.
	fail := func(err error) (*LuaRunner, error) {
		lr.ls.Close()
		return nil, err
	}

	for _, r := range resources {
		switch key := r.(type) {
		case LuaLib:
			lib, found := baseLuaLibs[key]
			if !found {
				return fail(fmt.Errorf("unsupported lua lib: %s", r))
			}
			if err := lr.loadLuaLib(string(key), lib); err != nil {
				return fail(fmt.Errorf("cannot load lua.%s base lib: %w", r, err))
			}
		case LuaFunc:
			fn, found := baseLuaFns[key]
			if !found {
				return fail(fmt.Errorf("unsupported lua function: %s", r))
			}
			lr.ls.SetGlobal(string(key), lr.ls.NewFunction(fn))
		default:
			return fail(fmt.Errorf("unsupported lua resource type %T", r))
		}
	}

	// NOTE(review): cleanLs is the same *LState as ls, not an independent
	// snapshot, so refreshLState cannot restore a pristine state. A real fix
	// likely needs to rebuild a new state from the resources on refresh.
	lr.cleanLs = lr.ls
	return lr, nil
}
// Run loads luaFn into the runner's Lua state, then calls the global function
// luaFnName with args. The args are converted to Lua values by
// convertArgsToLua. Run returns retNum results, each converted back to Go by
// luaToGoType.
//
// The whole operation runs with lsLock held. The errors are:
//   - a chunk load or execution failure, reported with the Lua cause when one
//     is available;
//   - an argument conversion failure;
//   - a call failure, returned as *LuaError;
//   - a result conversion failure.
//
// NOTE(review): results are read from the top of the Lua stack downwards, so
// results[0] is the LAST value the Lua function returned (e.g. `return 1, 2`
// yields [2 1]). This ordering is kept for existing callers; confirm it is
// intended.
func (lr *LuaRunner) Run(luaFnName LuaFunctionName, luaFn LuaFunctionCode, retNum LuaFunctionRetNum, args ...any) ([]any, error) {
	lr.lsLock.Lock()
	defer lr.lsLock.Unlock()

	// NOTE(review): nothing visible in this file sets isLsClean to false, so
	// this branch currently never runs and globals persist across Runs.
	if !lr.isLsClean {
		lr.refreshLState()
	}

	if err := lr.ls.DoString(string(luaFn)); err != nil {
		var luaErr *lua.ApiError
		if errors.As(err, &luaErr) && luaErr.Cause != nil {
			return nil, fmt.Errorf("lua script error: %s", luaErr.Cause.Error())
		}
		return nil, fmt.Errorf("cannot load string into lua state: %w", err)
	}

	luaArgs, err := lr.convertArgsToLua(args)
	if err != nil {
		return nil, fmt.Errorf("runLua: %w", err)
	}

	nRet := int(retNum)
	if err = lr.ls.CallByParam(lua.P{Fn: lr.ls.GetGlobal(string(luaFnName)), NRet: nRet, Protect: true}, luaArgs...); err != nil {
		return nil, &LuaError{err: fmt.Errorf("cannot run lua fn %s: %w", luaFnName, err)}
	}
	// A successful call leaves nRet values on the stack. Pop them on every
	// exit path (still under lsLock, since defers run LIFO) so repeated Runs
	// against the same state don't grow the stack.
	defer lr.ls.Pop(nRet)

	results := make([]any, nRet)
	for i := range results {
		// Negative indices count down from the top of the Lua stack.
		goVal, cErr := luaToGoType(lr.ls.Get(-1 - i))
		if cErr != nil {
			// Previously wrapped err, which is nil here, losing the cause.
			return nil, fmt.Errorf("lua - go conversion error: %w", cErr)
		}
		results[i] = goVal
	}
	return results, nil
}
// convertArgsToLua converts each Go value in args to a Lua value with
// goToLuaType, preserving order. It stops at the first value that cannot be
// converted and reports that value's Go type. Empty args yield a nil slice.
func (lr *LuaRunner) convertArgsToLua(args []any) ([]lua.LValue, error) {
	var converted []lua.LValue
	for idx := range args {
		v, convErr := goToLuaType(lr.ls, args[idx])
		if convErr != nil {
			return nil, fmt.Errorf("cannot convert arg %T to lua type: %w", args[idx], convErr)
		}
		converted = append(converted, v)
	}
	return converted, nil
}
// refreshLState replaces the runner's working state with cleanLs and marks
// the runner clean. It is called from Run while lsLock is held.
//
// The working state is closed only when it is a different object from
// cleanLs. NewLuaRunner sets cleanLs to the same *LState as ls. Closing that
// state and then reinstalling it would leave the runner executing on a
// closed state.
//
// NOTE(review): after a swap, ls and cleanLs are the same object again, so
// later refreshes cannot restore a pristine state.
func (lr *LuaRunner) refreshLState() {
	if lr.ls != lr.cleanLs {
		lr.ls.Close()
		lr.ls = lr.cleanLs
	}
	lr.isLsClean = true
}
// loadLuaLib opens a Lua library. It wraps the loader fn as a Lua function
// and calls it in protected mode, passing the library name as the only
// argument and discarding any return values. It returns the call error, if
// any.
func (lr *LuaRunner) loadLuaLib(name string, fn lua.LGFunction) error {
	opener := lr.ls.NewFunction(fn)
	params := lua.P{Fn: opener, NRet: 0, Protect: true}
	return lr.ls.CallByParam(params, lua.LString(name))
}