Browse code

Update hcsshim

Signed-off-by: Darren Stahl <darst@microsoft.com>

Darren Stahl authored on 2016/11/22 04:45:32
Showing 3 changed files
... ...
@@ -1,6 +1,6 @@
1 1
 # the following lines are in sorted order, FYI
2 2
 github.com/Azure/go-ansiterm 388960b655244e76e24c75f48631564eaefade62
3
-github.com/Microsoft/hcsshim v0.5.8
3
+github.com/Microsoft/hcsshim v0.5.9
4 4
 github.com/Microsoft/go-winio v0.3.6
5 5
 github.com/Sirupsen/logrus f76d643702a30fbffecdfe50831e11881c96ceb3 https://github.com/aaronlehmann/logrus
6 6
 github.com/davecgh/go-spew 6d212800a42e8ab5c146b8ace3490ee17e5225f9
... ...
@@ -57,6 +57,9 @@ import (
57 57
 	"io/ioutil"
58 58
 	"log"
59 59
 	"os"
60
+	"path/filepath"
61
+	"runtime"
62
+	"sort"
60 63
 	"strconv"
61 64
 	"strings"
62 65
 	"text/template"
... ...
@@ -65,6 +68,7 @@ import (
65 65
 var (
66 66
 	filename       = flag.String("output", "", "output file name (standard output if omitted)")
67 67
 	printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
68
+	systemDLL      = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory")
68 69
 )
69 70
 
70 71
 func trim(s string) string {
... ...
@@ -277,7 +281,7 @@ func (r *Rets) SetReturnValuesCode() string {
277 277
 func (r *Rets) useLongHandleErrorCode(retvar string) string {
278 278
 	const code = `if %s {
279 279
 		if e1 != 0 {
280
-			err = error(e1)
280
+			err = errnoErr(e1)
281 281
 		} else {
282 282
 			err = %sEINVAL
283 283
 		}
... ...
@@ -607,7 +611,6 @@ func (f *Fn) IsNotDuplicate() bool {
607 607
 		uniqDllFuncName[funcName] = true
608 608
 		return true
609 609
 	}
610
-
611 610
 	return false
612 611
 }
613 612
 
... ...
@@ -621,8 +624,20 @@ func (f *Fn) HelperName() string {
621 621
 
622 622
 // Source files and functions.
623 623
 type Source struct {
624
-	Funcs []*Fn
625
-	Files []string
624
+	Funcs           []*Fn
625
+	Files           []string
626
+	StdLibImports   []string
627
+	ExternalImports []string
628
+}
629
+
630
+func (src *Source) Import(pkg string) {
631
+	src.StdLibImports = append(src.StdLibImports, pkg)
632
+	sort.Strings(src.StdLibImports)
633
+}
634
+
635
+func (src *Source) ExternalImport(pkg string) {
636
+	src.ExternalImports = append(src.ExternalImports, pkg)
637
+	sort.Strings(src.ExternalImports)
626 638
 }
627 639
 
628 640
 // ParseFiles parses files listed in fs and extracts all syscall
... ...
@@ -632,6 +647,10 @@ func ParseFiles(fs []string) (*Source, error) {
632 632
 	src := &Source{
633 633
 		Funcs: make([]*Fn, 0),
634 634
 		Files: make([]string, 0),
635
+		StdLibImports: []string{
636
+			"unsafe",
637
+		},
638
+		ExternalImports: make([]string, 0),
635 639
 	}
636 640
 	for _, file := range fs {
637 641
 		if err := src.ParseFile(file); err != nil {
... ...
@@ -702,14 +721,81 @@ func (src *Source) ParseFile(path string) error {
702 702
 	return nil
703 703
 }
704 704
 
705
+// IsStdRepo returns true if src is part of standard library.
706
+func (src *Source) IsStdRepo() (bool, error) {
707
+	if len(src.Files) == 0 {
708
+		return false, errors.New("no input files provided")
709
+	}
710
+	abspath, err := filepath.Abs(src.Files[0])
711
+	if err != nil {
712
+		return false, err
713
+	}
714
+	goroot := runtime.GOROOT()
715
+	if runtime.GOOS == "windows" {
716
+		abspath = strings.ToLower(abspath)
717
+		goroot = strings.ToLower(goroot)
718
+	}
719
+	sep := string(os.PathSeparator)
720
+	if !strings.HasSuffix(goroot, sep) {
721
+		goroot += sep
722
+	}
723
+	return strings.HasPrefix(abspath, goroot), nil
724
+}
725
+
705 726
 // Generate output source file from a source set src.
706 727
 func (src *Source) Generate(w io.Writer) error {
728
+	const (
729
+		pkgStd         = iota // any package in std library
730
+		pkgXSysWindows        // x/sys/windows package
731
+		pkgOther
732
+	)
733
+	isStdRepo, err := src.IsStdRepo()
734
+	if err != nil {
735
+		return err
736
+	}
737
+	var pkgtype int
738
+	switch {
739
+	case isStdRepo:
740
+		pkgtype = pkgStd
741
+	case packageName == "windows":
742
+		// TODO: this needs better logic than just using package name
743
+		pkgtype = pkgXSysWindows
744
+	default:
745
+		pkgtype = pkgOther
746
+	}
747
+	if *systemDLL {
748
+		switch pkgtype {
749
+		case pkgStd:
750
+			src.Import("internal/syscall/windows/sysdll")
751
+		case pkgXSysWindows:
752
+		default:
753
+			src.ExternalImport("golang.org/x/sys/windows")
754
+		}
755
+	}
756
+	src.ExternalImport("github.com/Microsoft/go-winio")
757
+	if packageName != "syscall" {
758
+		src.Import("syscall")
759
+	}
707 760
 	funcMap := template.FuncMap{
708 761
 		"packagename": packagename,
709 762
 		"syscalldot":  syscalldot,
763
+		"newlazydll": func(dll string) string {
764
+			arg := "\"" + dll + ".dll\""
765
+			if !*systemDLL {
766
+				return syscalldot() + "NewLazyDLL(" + arg + ")"
767
+			}
768
+			switch pkgtype {
769
+			case pkgStd:
770
+				return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))"
771
+			case pkgXSysWindows:
772
+				return "NewLazySystemDLL(" + arg + ")"
773
+			default:
774
+				return "windows.NewLazySystemDLL(" + arg + ")"
775
+			}
776
+		},
710 777
 	}
711 778
 	t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate))
712
-	err := t.Execute(w, src)
779
+	err = t.Execute(w, src)
713 780
 	if err != nil {
714 781
 		return errors.New("Failed to execute template: " + err.Error())
715 782
 	}
... ...
@@ -761,12 +847,41 @@ const srcTemplate = `
761 761
 
762 762
 package {{packagename}}
763 763
 
764
-import "github.com/Microsoft/go-winio"
765
-import "unsafe"{{if syscalldot}}
766
-import "syscall"{{end}}
764
+import (
765
+{{range .StdLibImports}}"{{.}}"
766
+{{end}}
767
+
768
+{{range .ExternalImports}}"{{.}}"
769
+{{end}}
770
+)
767 771
 
768 772
 var _ unsafe.Pointer
769 773
 
774
+// Do the interface allocations only once for common
775
+// Errno values.
776
+const (
777
+	errnoERROR_IO_PENDING = 997
778
+)
779
+
780
+var (
781
+	errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING)
782
+)
783
+
784
+// errnoErr returns common boxed Errno values, to prevent
785
+// allocations at runtime.
786
+func errnoErr(e {{syscalldot}}Errno) error {
787
+	switch e {
788
+	case 0:
789
+		return nil
790
+	case errnoERROR_IO_PENDING:
791
+		return errERROR_IO_PENDING
792
+	}
793
+	// TODO: add more here, after collecting data on the common
794
+	// error values see on Windows. (perhaps when running
795
+	// all.bat?)
796
+	return e
797
+}
798
+
770 799
 var (
771 800
 {{template "dlls" .}}
772 801
 {{template "funcnames" .}})
... ...
@@ -775,7 +890,7 @@ var (
775 775
 
776 776
 {{/* help functions */}}
777 777
 
778
-{{define "dlls"}}{{range .DLLs}}	mod{{.}} = {{syscalldot}}NewLazyDLL("{{.}}.dll")
778
+{{define "dlls"}}{{range .DLLs}}	mod{{.}} = {{newlazydll .}}
779 779
 {{end}}{{end}}
780 780
 
781 781
 {{define "funcnames"}}{{range .Funcs}}{{if .IsNotDuplicate}}	proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}"){{end}}
... ...
@@ -802,12 +917,13 @@ func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
802 802
 
803 803
 {{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}
804 804
 
805
+{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}
806
+
805 807
 {{define "syscallcheck"}}{{if .ConfirmProc}}if {{.Rets.ErrorVarName}} = proc{{.DLLFuncName}}.Find(); {{.Rets.ErrorVarName}} != nil {
806 808
     return
807 809
 }
808 810
 {{end}}{{end}}
809 811
 
810
-{{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}
811 812
 
812 813
 {{define "seterror"}}{{if .Rets.SetErrorCode}}	{{.Rets.SetErrorCode}}
813 814
 {{end}}{{end}}
... ...
@@ -2,16 +2,45 @@
2 2
 
3 3
 package hcsshim
4 4
 
5
-import "github.com/Microsoft/go-winio"
6
-import "unsafe"
7
-import "syscall"
5
+import (
6
+	"syscall"
7
+	"unsafe"
8
+
9
+	"github.com/Microsoft/go-winio"
10
+	"golang.org/x/sys/windows"
11
+)
8 12
 
9 13
 var _ unsafe.Pointer
10 14
 
15
+// Do the interface allocations only once for common
16
+// Errno values.
17
+const (
18
+	errnoERROR_IO_PENDING = 997
19
+)
20
+
21
+var (
22
+	errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
23
+)
24
+
25
+// errnoErr returns common boxed Errno values, to prevent
26
+// allocations at runtime.
27
+func errnoErr(e syscall.Errno) error {
28
+	switch e {
29
+	case 0:
30
+		return nil
31
+	case errnoERROR_IO_PENDING:
32
+		return errERROR_IO_PENDING
33
+	}
34
+	// TODO: add more here, after collecting data on the common
35
+	// error values see on Windows. (perhaps when running
36
+	// all.bat?)
37
+	return e
38
+}
39
+
11 40
 var (
12
-	modole32     = syscall.NewLazyDLL("ole32.dll")
13
-	modiphlpapi  = syscall.NewLazyDLL("iphlpapi.dll")
14
-	modvmcompute = syscall.NewLazyDLL("vmcompute.dll")
41
+	modole32     = windows.NewLazySystemDLL("ole32.dll")
42
+	modiphlpapi  = windows.NewLazySystemDLL("iphlpapi.dll")
43
+	modvmcompute = windows.NewLazySystemDLL("vmcompute.dll")
15 44
 
16 45
 	procCoTaskMemFree                      = modole32.NewProc("CoTaskMemFree")
17 46
 	procSetCurrentThreadCompartmentId      = modiphlpapi.NewProc("SetCurrentThreadCompartmentId")