|
@@ -57,6 +57,9 @@ import (
|
|
|
"io/ioutil"
|
|
|
"log"
|
|
|
"os"
|
|
|
+ "path/filepath"
|
|
|
+ "runtime"
|
|
|
+ "sort"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
"text/template"
|
|
@@ -65,6 +68,7 @@ import (
|
|
|
var (
|
|
|
filename = flag.String("output", "", "output file name (standard output if omitted)")
|
|
|
printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
|
|
|
+ systemDLL = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory")
|
|
|
)
|
|
|
|
|
|
func trim(s string) string {
|
|
@@ -277,7 +281,7 @@ func (r *Rets) SetReturnValuesCode() string {
|
|
|
func (r *Rets) useLongHandleErrorCode(retvar string) string {
|
|
|
const code = `if %s {
|
|
|
if e1 != 0 {
|
|
|
- err = error(e1)
|
|
|
+ err = errnoErr(e1)
|
|
|
} else {
|
|
|
err = %sEINVAL
|
|
|
}
|
|
@@ -607,7 +611,6 @@ func (f *Fn) IsNotDuplicate() bool {
|
|
|
uniqDllFuncName[funcName] = true
|
|
|
return true
|
|
|
}
|
|
|
-
|
|
|
return false
|
|
|
}
|
|
|
|
|
@@ -621,8 +624,20 @@ func (f *Fn) HelperName() string {
|
|
|
|
|
|
// Source files and functions.
|
|
|
type Source struct {
|
|
|
- Funcs []*Fn
|
|
|
- Files []string
|
|
|
+ Funcs []*Fn
|
|
|
+ Files []string
|
|
|
+ StdLibImports []string
|
|
|
+ ExternalImports []string
|
|
|
+}
|
|
|
+
|
|
|
+func (src *Source) Import(pkg string) {
|
|
|
+ src.StdLibImports = append(src.StdLibImports, pkg)
|
|
|
+ sort.Strings(src.StdLibImports)
|
|
|
+}
|
|
|
+
|
|
|
+func (src *Source) ExternalImport(pkg string) {
|
|
|
+ src.ExternalImports = append(src.ExternalImports, pkg)
|
|
|
+ sort.Strings(src.ExternalImports)
|
|
|
}
|
|
|
|
|
|
// ParseFiles parses files listed in fs and extracts all syscall
|
|
@@ -632,6 +647,10 @@ func ParseFiles(fs []string) (*Source, error) {
|
|
|
src := &Source{
|
|
|
Funcs: make([]*Fn, 0),
|
|
|
Files: make([]string, 0),
|
|
|
+ StdLibImports: []string{
|
|
|
+ "unsafe",
|
|
|
+ },
|
|
|
+ ExternalImports: make([]string, 0),
|
|
|
}
|
|
|
for _, file := range fs {
|
|
|
if err := src.ParseFile(file); err != nil {
|
|
@@ -702,14 +721,81 @@ func (src *Source) ParseFile(path string) error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+// IsStdRepo returns true if src is part of standard library.
|
|
|
+func (src *Source) IsStdRepo() (bool, error) {
|
|
|
+ if len(src.Files) == 0 {
|
|
|
+ return false, errors.New("no input files provided")
|
|
|
+ }
|
|
|
+ abspath, err := filepath.Abs(src.Files[0])
|
|
|
+ if err != nil {
|
|
|
+ return false, err
|
|
|
+ }
|
|
|
+ goroot := runtime.GOROOT()
|
|
|
+ if runtime.GOOS == "windows" {
|
|
|
+ abspath = strings.ToLower(abspath)
|
|
|
+ goroot = strings.ToLower(goroot)
|
|
|
+ }
|
|
|
+ sep := string(os.PathSeparator)
|
|
|
+ if !strings.HasSuffix(goroot, sep) {
|
|
|
+ goroot += sep
|
|
|
+ }
|
|
|
+ return strings.HasPrefix(abspath, goroot), nil
|
|
|
+}
|
|
|
+
|
|
|
// Generate output source file from a source set src.
|
|
|
func (src *Source) Generate(w io.Writer) error {
|
|
|
+ const (
|
|
|
+ pkgStd = iota // any package in std library
|
|
|
+ pkgXSysWindows // x/sys/windows package
|
|
|
+ pkgOther
|
|
|
+ )
|
|
|
+ isStdRepo, err := src.IsStdRepo()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ var pkgtype int
|
|
|
+ switch {
|
|
|
+ case isStdRepo:
|
|
|
+ pkgtype = pkgStd
|
|
|
+ case packageName == "windows":
|
|
|
+ // TODO: this needs better logic than just using package name
|
|
|
+ pkgtype = pkgXSysWindows
|
|
|
+ default:
|
|
|
+ pkgtype = pkgOther
|
|
|
+ }
|
|
|
+ if *systemDLL {
|
|
|
+ switch pkgtype {
|
|
|
+ case pkgStd:
|
|
|
+ src.Import("internal/syscall/windows/sysdll")
|
|
|
+ case pkgXSysWindows:
|
|
|
+ default:
|
|
|
+ src.ExternalImport("golang.org/x/sys/windows")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ src.ExternalImport("github.com/Microsoft/go-winio")
|
|
|
+ if packageName != "syscall" {
|
|
|
+ src.Import("syscall")
|
|
|
+ }
|
|
|
funcMap := template.FuncMap{
|
|
|
"packagename": packagename,
|
|
|
"syscalldot": syscalldot,
|
|
|
+ "newlazydll": func(dll string) string {
|
|
|
+ arg := "\"" + dll + ".dll\""
|
|
|
+ if !*systemDLL {
|
|
|
+ return syscalldot() + "NewLazyDLL(" + arg + ")"
|
|
|
+ }
|
|
|
+ switch pkgtype {
|
|
|
+ case pkgStd:
|
|
|
+ return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))"
|
|
|
+ case pkgXSysWindows:
|
|
|
+ return "NewLazySystemDLL(" + arg + ")"
|
|
|
+ default:
|
|
|
+ return "windows.NewLazySystemDLL(" + arg + ")"
|
|
|
+ }
|
|
|
+ },
|
|
|
}
|
|
|
t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate))
|
|
|
- err := t.Execute(w, src)
|
|
|
+ err = t.Execute(w, src)
|
|
|
if err != nil {
|
|
|
return errors.New("Failed to execute template: " + err.Error())
|
|
|
}
|
|
@@ -761,12 +847,41 @@ const srcTemplate = `
|
|
|
|
|
|
package {{packagename}}
|
|
|
|
|
|
-import "github.com/Microsoft/go-winio"
|
|
|
-import "unsafe"{{if syscalldot}}
|
|
|
-import "syscall"{{end}}
|
|
|
+import (
|
|
|
+{{range .StdLibImports}}"{{.}}"
|
|
|
+{{end}}
|
|
|
+
|
|
|
+{{range .ExternalImports}}"{{.}}"
|
|
|
+{{end}}
|
|
|
+)
|
|
|
|
|
|
var _ unsafe.Pointer
|
|
|
|
|
|
+// Do the interface allocations only once for common
|
|
|
+// Errno values.
|
|
|
+const (
|
|
|
+ errnoERROR_IO_PENDING = 997
|
|
|
+)
|
|
|
+
|
|
|
+var (
|
|
|
+ errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING)
|
|
|
+)
|
|
|
+
|
|
|
+// errnoErr returns common boxed Errno values, to prevent
|
|
|
+// allocations at runtime.
|
|
|
+func errnoErr(e {{syscalldot}}Errno) error {
|
|
|
+ switch e {
|
|
|
+ case 0:
|
|
|
+ return nil
|
|
|
+ case errnoERROR_IO_PENDING:
|
|
|
+ return errERROR_IO_PENDING
|
|
|
+ }
|
|
|
+ // TODO: add more here, after collecting data on the common
|
|
|
+ // error values see on Windows. (perhaps when running
|
|
|
+ // all.bat?)
|
|
|
+ return e
|
|
|
+}
|
|
|
+
|
|
|
var (
|
|
|
{{template "dlls" .}}
|
|
|
{{template "funcnames" .}})
|
|
@@ -775,7 +890,7 @@ var (
|
|
|
|
|
|
{{/* help functions */}}
|
|
|
|
|
|
-{{define "dlls"}}{{range .DLLs}} mod{{.}} = {{syscalldot}}NewLazyDLL("{{.}}.dll")
|
|
|
+{{define "dlls"}}{{range .DLLs}} mod{{.}} = {{newlazydll .}}
|
|
|
{{end}}{{end}}
|
|
|
|
|
|
{{define "funcnames"}}{{range .Funcs}}{{if .IsNotDuplicate}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}"){{end}}
|
|
@@ -802,12 +917,13 @@ func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
|
|
|
|
|
|
{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}
|
|
|
|
|
|
+{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}
|
|
|
+
|
|
|
{{define "syscallcheck"}}{{if .ConfirmProc}}if {{.Rets.ErrorVarName}} = proc{{.DLLFuncName}}.Find(); {{.Rets.ErrorVarName}} != nil {
|
|
|
return
|
|
|
}
|
|
|
{{end}}{{end}}
|
|
|
|
|
|
-{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}
|
|
|
|
|
|
{{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}}
|
|
|
{{end}}{{end}}
|