diff --git a/build/golang.go b/build/golang.go index 4b05d1b1..de7884a3 100644 --- a/build/golang.go +++ b/build/golang.go @@ -27,6 +27,7 @@ import ( "path/filepath" "github.com/loopholelabs/scale/compile/golang" + "github.com/loopholelabs/scale/extension" "github.com/loopholelabs/scale/scalefile" "github.com/loopholelabs/scale/scalefunc" "github.com/loopholelabs/scale/signature" @@ -68,6 +69,8 @@ type LocalGolangOptions struct { // Args are the optional arguments to pass to the compiler Args []string + + Extensions []extension.ExtensionInfo } func LocalGolang(options *LocalGolangOptions) (*scalefunc.Schema, error) { @@ -156,7 +159,7 @@ func LocalGolang(options *LocalGolangOptions) (*scalefunc.Schema, error) { _ = options.Storage.Delete(build) }() - modfile, err := golang.GenerateGoModfile(options.Scalefile, signatureDependencyPath, signatureDependencyVersion, options.SourceDirectory) + modfile, err := golang.GenerateGoModfile(options.Scalefile, signatureDependencyPath, signatureDependencyVersion, options.SourceDirectory, options.Extensions) if err != nil { return nil, fmt.Errorf("unable to generate go.mod file: %w", err) } diff --git a/compile/golang/generator.go b/compile/golang/generator.go index a887bf07..bc9ce6b2 100644 --- a/compile/golang/generator.go +++ b/compile/golang/generator.go @@ -22,14 +22,15 @@ import ( "github.com/loopholelabs/scale/version" "github.com/loopholelabs/scale/compile/golang/templates" + "github.com/loopholelabs/scale/extension" "github.com/loopholelabs/scale/scalefile" "github.com/loopholelabs/scale/signature" ) var generator *Generator -func GenerateGoModfile(packageSchema *scalefile.Schema, signatureImport string, signatureVersion string, functionImport string) ([]byte, error) { - return generator.GenerateGoModfile(packageSchema, signatureImport, signatureVersion, functionImport) +func GenerateGoModfile(packageSchema *scalefile.Schema, signatureImport string, signatureVersion string, functionImport string, extensions []extension.ExtensionInfo) ([]byte, error) { + return generator.GenerateGoModfile(packageSchema, signatureImport, signatureVersion, functionImport, extensions) } func GenerateGoMain(packageSchema *scalefile.Schema, signatureSchema *signature.Schema) ([]byte, error) { @@ -50,7 +51,7 @@ func New() *Generator { } } -func (g *Generator) GenerateGoModfile(packageSchema *scalefile.Schema, signatureImport string, signatureVersion string, functionImport string) ([]byte, error) { +func (g *Generator) GenerateGoModfile(packageSchema *scalefile.Schema, signatureImport string, signatureVersion string, functionImport string, extensions []extension.ExtensionInfo) ([]byte, error) { if signatureVersion == "" && !strings.HasPrefix(signatureImport, "/") && !strings.HasPrefix(signatureImport, "./") && !strings.HasPrefix(signatureImport, "../") { signatureImport = "./" + signatureImport } @@ -62,6 +63,7 @@ func (g *Generator) GenerateGoModfile(packageSchema *scalefile.Schema, signature buf := new(bytes.Buffer) err := g.template.ExecuteTemplate(buf, "mod.go.templ", map[string]interface{}{ "package_schema": packageSchema, + "extensions": extensions, "signature_import": signatureImport, "signature_version": signatureVersion, "function_import": functionImport, diff --git a/compile/golang/templates/mod.go.templ b/compile/golang/templates/mod.go.templ index e4c1d6a2..df708c6a 100644 --- a/compile/golang/templates/mod.go.templ +++ b/compile/golang/templates/mod.go.templ @@ -8,4 +8,10 @@ replace {{ .package_schema.Name }} v0.1.0 => {{ .function_import }} require ( signature v0.1.0 {{ .package_schema.Name }} v0.1.0 -) \ No newline at end of file +) + +{{ range $extension := .extensions -}} + replace {{ $extension.Name }} => {{ $extension.Path }} + + require {{ $extension.Name }} {{ $extension.Version }} +{{end -}} diff --git a/config.go b/config.go index 3372ba8b..b3523ad3 100644 --- a/config.go +++ b/config.go @@ -22,6 +22,7 @@ import ( "io" "regexp" + extension "github.com/loopholelabs/scale-extension-interfaces" interfaces "github.com/loopholelabs/scale-signature-interfaces" "github.com/loopholelabs/scale/scalefunc" ) @@ -50,6 +51,7 @@ type Config[T interfaces.Signature] struct { stdout io.Writer stderr io.Writer rawOutput bool + extensions []extension.Extension } // NewConfig returns a new Scale Runtime Config @@ -85,6 +87,11 @@ func (c *Config[T]) validate() error { return nil } +func (c *Config[T]) WithExtension(e extension.Extension) *Config[T] { + c.extensions = append(c.extensions, e) + return c +} + func (c *Config[T]) WithSignature(newSignature interfaces.New[T]) *Config[T] { c.newSignature = newSignature return c diff --git a/extension/function.go b/extension/function.go new file mode 100644 index 00000000..d615bb5c --- /dev/null +++ b/extension/function.go @@ -0,0 +1,24 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package extension + +type FunctionSchema struct { + Name string `hcl:"name,label"` + Description string `hcl:"description,optional"` + Params string `hcl:"params,optional"` + Return string `hcl:"return,optional"` +} diff --git a/extension/generator/file.go b/extension/generator/file.go new file mode 100644 index 00000000..85ef2026 --- /dev/null +++ b/extension/generator/file.go @@ -0,0 +1,88 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package generator + +import ( + "bytes" + "io" + "io/fs" + "os" + "time" + + "golang.org/x/mod/zip" +) + +var _ zip.File = (*File)(nil) +var _ os.FileInfo = (*File)(nil) + +type File struct { + name string + path string + content []byte + reader *bytes.Reader + size int64 +} + +func NewFile(name string, path string, content []byte) File { + return File{ + name: name, + path: path, + content: content, + reader: bytes.NewReader(content), + size: int64(len(content)), + } +} + +func (g File) Name() string { + return g.name +} + +func (g File) Size() int64 { + return g.size +} + +func (g File) Mode() fs.FileMode { + return 0700 +} + +func (g File) ModTime() time.Time { + return time.Now() +} + +func (g File) IsDir() bool { + return false +} + +func (g File) Sys() any { + return g.content +} + +func (g File) Path() string { + return g.path +} + +func (g File) Lstat() (os.FileInfo, error) { + return g, nil +} + +func (g File) Open() (io.ReadCloser, error) { + return io.NopCloser(g.reader), nil +} + +func (g File) Data() []byte { + return g.content +} diff --git a/extension/generator/generator.go b/extension/generator/generator.go new file mode 100644 index 00000000..8d604095 --- /dev/null +++ b/extension/generator/generator.go @@ -0,0 +1,158 @@ +/* + Copyright 2023 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package generator + +import ( + "bytes" + "encoding/hex" + + "github.com/loopholelabs/scale/extension" + "github.com/loopholelabs/scale/extension/generator/golang" + "github.com/loopholelabs/scale/extension/generator/rust" +) + +type GuestRegistryPackage struct { + GolangModule *bytes.Buffer + GolangModfile []byte + RustCrate *bytes.Buffer + RustCargofile []byte + TypescriptPackage *bytes.Buffer + TypescriptPackageJSON []byte +} + +type GuestLocalPackage struct { + GolangFiles []File + RustFiles []File + TypescriptFiles []File +} + +type HostRegistryPackage struct { + GolangModule *bytes.Buffer + GolangModfile []byte + TypescriptPackage *bytes.Buffer + TypescriptPackageJSON []byte +} + +type HostLocalPackage struct { + GolangFiles []File + TypescriptFiles []File +} + +type Options struct { + Extension *extension.Schema + + GolangPackageImportPath string + GolangPackageName string + GolangPackageVersion string + + RustPackageName string + RustPackageVersion string + + TypescriptPackageName string + TypescriptPackageVersion string +} + +func GenerateGuestLocal(options *Options) (*GuestLocalPackage, error) { + hash, err := options.Extension.Hash() + if err != nil { + return nil, err + } + hashString := hex.EncodeToString(hash) + + golangTypes, err := golang.GenerateTypes(options.Extension, options.GolangPackageName) + if err != nil { + return nil, err + } + + golangGuest, err := golang.GenerateGuest(options.Extension, options.GolangPackageName, options.GolangPackageVersion) + if err != nil { + return nil, err + } + + golangInterfaces, err := golang.GenerateInterfaces(options.Extension, options.GolangPackageName, options.GolangPackageVersion) + if err != nil { + return nil, err + } + + modfile, err := golang.GenerateModfile(options.GolangPackageName) + if err != nil { + return nil, err + } + + golangFiles := []File{ + NewFile("types.go", "types.go", golangTypes), + NewFile("guest.go", "guest.go", golangGuest), + NewFile("interfaces.go", "interfaces.go", golangInterfaces), + NewFile("go.mod", "go.mod", modfile), + } + + rustTypes, err := rust.GenerateTypes(options.Extension, options.RustPackageName) + if err != nil { + return nil, err + } + + rustGuest, err := rust.GenerateGuest(options.Extension, hashString, options.RustPackageName) + if err != nil { + return nil, err + } + + cargofile, err := rust.GenerateCargofile(options.RustPackageName, options.RustPackageVersion) + if err != nil { + return nil, err + } + + rustFiles := []File{ + NewFile("types.rs", "types.rs", rustTypes), + NewFile("guest.rs", "guest.rs", rustGuest), + NewFile("Cargo.toml", "Cargo.toml", cargofile), + } + + return &GuestLocalPackage{ + GolangFiles: golangFiles, + RustFiles: rustFiles, + }, nil +} + +func GenerateHostLocal(options *Options) (*HostLocalPackage, error) { + golangTypes, err := golang.GenerateTypes(options.Extension, options.GolangPackageName) + if err != nil { + return nil, err + } + + golangHost, err := golang.GenerateHost(options.Extension, options.GolangPackageName, options.GolangPackageVersion) + if err != nil { + return nil, err + } + + golangInterfaces, err := golang.GenerateInterfaces(options.Extension, options.GolangPackageName, options.GolangPackageVersion) + if err != nil { + return nil, err + } + + modfile, err := golang.GenerateModfile(options.GolangPackageName) + if err != nil { + return nil, err + } + + golangFiles := []File{ + NewFile("types.go", "types.go", golangTypes), + NewFile("host.go", "host.go", golangHost), + NewFile("interfaces.go", "interfaces.go", golangInterfaces), + NewFile("go.mod", "go.mod", modfile), + } + + return &HostLocalPackage{ + GolangFiles: golangFiles, + }, nil +} diff --git a/extension/generator/golang/generated.txt b/extension/generator/golang/generated.txt new file mode 100644 index 00000000..168c93d4 --- /dev/null +++ b/extension/generator/golang/generated.txt @@ -0,0 +1,270 @@ +// Code generated by scale-signature v0.4.1, DO NOT EDIT. +// output: types + +package types + +import ( + "errors" + "github.com/loopholelabs/polyglot" +) + +var ( + NilDecode = errors.New("cannot decode into a nil root struct") + InvalidEnum = errors.New("invalid enum value") +) + +type HttpConfig struct { + Timeout int32 +} + +func NewHttpConfig() *HttpConfig { + return &HttpConfig{ + + Timeout: 60, + } +} + +func (x *HttpConfig) Encode(b *polyglot.Buffer) { + e := polyglot.Encoder(b) + if x == nil { + e.Nil() + } else { + + e.Int32(x.Timeout) + + } +} + +func DecodeHttpConfig(x *HttpConfig, b []byte) (*HttpConfig, error) { + d := polyglot.GetDecoder(b) + defer d.Return() + return _decodeHttpConfig(x, d) +} + +func _decodeHttpConfig(x *HttpConfig, d *polyglot.Decoder) (*HttpConfig, error) { + if d.Nil() { + return nil, nil + } + + err, _ := d.Error() + if err != nil { + return nil, err + } + + if x == nil { + x = NewHttpConfig() + } + + x.Timeout, err = d.Int32() + if err != nil { + return nil, err + } + + return x, nil +} + +type HttpResponse struct { + Headers map[string]StringList + + StatusCode int32 + + Body []byte +} + +func NewHttpResponse() *HttpResponse { + return &HttpResponse{ + + Headers: make(map[string]StringList), + + StatusCode: 0, + + Body: make([]byte, 0, 0), + } +} + +func (x *HttpResponse) Encode(b *polyglot.Buffer) { + e := polyglot.Encoder(b) + if x == nil { + e.Nil() + } else { + + e.Map(uint32(len(x.Headers)), polyglot.StringKind, polyglot.AnyKind) + for k, v := range x.Headers { + e.String(k) + v.Encode(b) + } + + e.Int32(x.StatusCode) + + e.Bytes(x.Body) + + } +} + +func DecodeHttpResponse(x *HttpResponse, b []byte) (*HttpResponse, error) { + d := polyglot.GetDecoder(b) + defer d.Return() + return _decodeHttpResponse(x, d) +} + +func _decodeHttpResponse(x *HttpResponse, d *polyglot.Decoder) (*HttpResponse, error) { + if d.Nil() { + return nil, nil + } + + err, _ := d.Error() + if err != nil { + return nil, err + } + + if x == nil { + x = NewHttpResponse() + } + + mapSizeHeaders, err := d.Map(polyglot.StringKind, polyglot.AnyKind) + if err != nil { + return nil, err + } + + if uint32(len(x.Headers)) != mapSizeHeaders { + x.Headers = make(map[string]StringList, mapSizeHeaders) + } + + for i := uint32(0); i < mapSizeHeaders; i++ { + k, err := d.String() + if err != nil { + return nil, err + } + v, err := _decodeStringList(nil, d) + if err != nil { + return nil, err + } + x.Headers[k] = *v + } + + x.StatusCode, err = d.Int32() + if err != nil { + return nil, err + } + + x.Body, err = d.Bytes(nil) + if err != nil { + return nil, err + } + + return x, nil +} + +type StringList struct { + Values []string +} + +func NewStringList() *StringList { + return &StringList{ + + Values: make([]string, 0, 0), + } +} + +func (x *StringList) Encode(b *polyglot.Buffer) { + e := polyglot.Encoder(b) + if x == nil { + e.Nil() + } else { + + e.Slice(uint32(len(x.Values)), polyglot.StringKind) + for _, a := range x.Values { + e.String(a) + } + + } +} + +func DecodeStringList(x *StringList, b []byte) (*StringList, error) { + d := polyglot.GetDecoder(b) + defer d.Return() + return _decodeStringList(x, d) +} + +func _decodeStringList(x *StringList, d *polyglot.Decoder) (*StringList, error) { + if d.Nil() { + return nil, nil + } + + err, _ := d.Error() + if err != nil { + return nil, err + } + + if x == nil { + x = NewStringList() + } + + sliceSizeValues, err := d.Slice(polyglot.StringKind) + if err != nil { + return nil, err + } + + if uint32(len(x.Values)) != sliceSizeValues { + x.Values = make([]string, sliceSizeValues) + } + + for i := uint32(0); i < sliceSizeValues; i++ { + x.Values[i], err = d.String() + if err != nil { + return nil, err + } + } + + return x, nil +} + +type ConnectionDetails struct { + Url string +} + +func NewConnectionDetails() *ConnectionDetails { + return &ConnectionDetails{ + + Url: "https://google.com", + } +} + +func (x *ConnectionDetails) Encode(b *polyglot.Buffer) { + e := polyglot.Encoder(b) + if x == nil { + e.Nil() + } else { + + e.String(x.Url) + + } +} + +func DecodeConnectionDetails(x *ConnectionDetails, b []byte) (*ConnectionDetails, error) { + d := polyglot.GetDecoder(b) + defer d.Return() + return _decodeConnectionDetails(x, d) +} + +func _decodeConnectionDetails(x *ConnectionDetails, d *polyglot.Decoder) (*ConnectionDetails, error) { + if d.Nil() { + return nil, nil + } + + err, _ := d.Error() + if err != nil { + return nil, err + } + + if x == nil { + x = NewConnectionDetails() + } + + x.Url, err = d.String() + if err != nil { + return nil, err + } + + return x, nil +} diff --git a/extension/generator/golang/generator.go b/extension/generator/golang/generator.go new file mode 100644 index 00000000..2ce81a21 --- /dev/null +++ b/extension/generator/golang/generator.go @@ -0,0 +1,284 @@ +/* + Copyright 2023 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package golang + +import ( + "bytes" + "go/format" + "text/template" + + polyglotVersion "github.com/loopholelabs/polyglot/version" + interfacesVersion "github.com/loopholelabs/scale-extension-interfaces/version" + "github.com/loopholelabs/scale/extension" + + scaleVersion "github.com/loopholelabs/scale/version" + + "github.com/loopholelabs/scale/extension/generator/golang/templates" + "github.com/loopholelabs/scale/signature/generator/utils" +) + +const ( + defaultPackageName = "types" +) + +var generator *Generator + +func GenerateTypes(schema *extension.Schema, packageName string) ([]byte, error) { + return generator.GenerateTypes(schema, packageName) +} + +func GenerateInterfaces(schema *extension.Schema, packageName string, version string) ([]byte, error) { + return generator.GenerateInterfaces(schema, packageName, version) +} + +func GenerateGuest(schema *extension.Schema, packageName string, version string) ([]byte, error) { + return generator.GenerateGuest(schema, packageName, version) +} + +func GenerateModfile(packageName string) ([]byte, error) { + return generator.GenerateModfile(packageName) +} + +func GenerateHost(schema *extension.Schema, packageName string, version string) ([]byte, error) { + return generator.GenerateHost(schema, packageName, version) +} + +func init() { + var err error + generator, err = New() + if err != nil { + panic(err) + } +} + +// Generator is the go generator +type Generator struct { + templ *template.Template +} + +// New creates a new go generator +func New() (*Generator, error) { + templ, err := template.New("").Funcs(templateFunctions()).ParseFS(templates.FS, "*.go.templ") + if err != nil { + return nil, err + } + + return &Generator{ + templ: templ, + }, nil +} + +// Generate generates the go code +func (g *Generator) GenerateTypes(schema *extension.Schema, packageName string) ([]byte, error) { + if packageName == "" { + packageName = defaultPackageName + } + + ext, err := schema.CloneWithDisabledAccessorsValidatorsAndModifiers() + if err != nil { + return nil, err + } + + buf := new(bytes.Buffer) + err = g.templ.ExecuteTemplate(buf, "types.go.templ", map[string]any{ + "signature_schema": ext, + "generator_version": scaleVersion.Version(), + "package_name": packageName, + }) + if err != nil { + return nil, err + } + + return format.Source(buf.Bytes()) +} + +func (g *Generator) GenerateInterfaces(schema *extension.Schema, packageName string, version string) ([]byte, error) { + if packageName == "" { + packageName = defaultPackageName + } + + buf := new(bytes.Buffer) + err := g.templ.ExecuteTemplate(buf, "interfaces.go.templ", map[string]any{ + "schema": schema, + "version": version, + "package": packageName, + }) + if err != nil { + return nil, err + } + + return format.Source(buf.Bytes()) +} + +// GenerateGuest generates the guest bindings +func (g *Generator) GenerateGuest(schema *extension.Schema, packageName string, version string) ([]byte, error) { + if packageName == "" { + packageName = defaultPackageName + } + + buf := new(bytes.Buffer) + err := g.templ.ExecuteTemplate(buf, "guest.go.templ", map[string]any{ + "schema": schema, + "version": version, + "package": packageName, + }) + if err != nil { + return nil, err + } + + return format.Source(buf.Bytes()) +} + +// GenerateModfile generates the modfile for the signature +func (g *Generator) GenerateModfile(packageImportPath string) ([]byte, error) { + buf := new(bytes.Buffer) + err := g.templ.ExecuteTemplate(buf, "mod.go.templ", map[string]any{ + "polyglot_version": polyglotVersion.Version(), + "scale_extension_interfaces_version": interfacesVersion.Version(), + "package_import_path": packageImportPath, + }) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// GenerateHost generates the host bindings +func (g *Generator) GenerateHost(schema *extension.Schema, packageName string, version string) ([]byte, error) { + if packageName == "" { + packageName = defaultPackageName + } + + buf := new(bytes.Buffer) + err := g.templ.ExecuteTemplate(buf, "host.go.templ", map[string]any{ + "schema": schema, + "version": version, + "package": packageName, + }) + if err != nil { + return nil, err + } + + return format.Source(buf.Bytes()) +} + +func templateFunctions() template.FuncMap { + return template.FuncMap{ + "IsInterface": isInterface, + "Primitive": primitive, + "IsPrimitive": extension.ValidPrimitiveType, + "PolyglotPrimitive": polyglotPrimitive, + "PolyglotPrimitiveEncode": polyglotPrimitiveEncode, + "PolyglotPrimitiveDecode": polyglotPrimitiveDecode, + "Deref": func(i *bool) bool { return *i }, + "LowerFirst": func(s string) string { return string(s[0]+32) + s[1:] }, + "Params": utils.Params, + } +} + +func isInterface(schema *extension.Schema, s string) bool { + for _, i := range schema.Interfaces { + if i.Name == s { + return true + } + } + return false +} + +func primitive(t string) string { + switch t { + case "string", "int32", "int64", "uint32", "uint64", "float32", "float64", "bool": + return t + case "bytes": + return "[]byte" + default: + return "" + } +} + +func polyglotPrimitive(t string) string { + switch t { + case "string": + return "polyglot.StringKind" + case "int32": + return "polyglot.Int32Kind" + case "int64": + return "polyglot.Int64Kind" + case "uint32": + return "polyglot.Uint32Kind" + case "uint64": + return "polyglot.Uint64Kind" + case "float32": + return "polyglot.Float32Kind" + case "float64": + return "polyglot.Float64Kind" + case "bool": + return "polyglot.BoolKind" + case "bytes": + return "polyglot.BytesKind" + default: + return "polyglot.AnyKind" + } +} + +func polyglotPrimitiveEncode(t string) string { + switch t { + case "string": + return "String" + case "int32": + return "Int32" + case "int64": + return "Int64" + case "uint32": + return "Uint32" + case "uint64": + return "Uint64" + case "float32": + return "Float32" + case "float64": + return "Float64" + case "bool": + return "Bool" + case "bytes": + return "Bytes" + default: + return "" + } +} + +func polyglotPrimitiveDecode(t string) string { + switch t { + case "string": + return "String" + case "int32": + return "Int32" + case "int64": + return "Int64" + case "uint32": + return "Uint32" + case "uint64": + return "Uint64" + case "float32": + return "Float32" + case "float64": + return "Float64" + case "bool": + return "Bool" + case "bytes": + return "Bytes" + default: + return "" + } +} diff --git a/extension/generator/golang/generator_test.go b/extension/generator/golang/generator_test.go new file mode 100644 index 00000000..429290ab --- /dev/null +++ b/extension/generator/golang/generator_test.go @@ -0,0 +1,72 @@ +//go:build !integration + +/* + Copyright 2023 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package golang + +import ( + "os" + "testing" + + "github.com/loopholelabs/scale/extension" + "github.com/stretchr/testify/require" +) + +func TestGenerator(t *testing.T) { + s := new(extension.Schema) + err := s.Decode([]byte(extension.MasterTestingSchema)) + require.NoError(t, err) + + require.NoError(t, s.Validate()) + + packageName := "extfetch" + + interfaces, err := GenerateInterfaces(s, packageName, "v0.1.0") + require.NoError(t, err) + // os.WriteFile("./interfaces.txt", interfaces, 0644) + + expInterfaces, err := os.ReadFile("./interfaces.txt") + require.NoError(t, err) + require.Equal(t, string(expInterfaces), string(interfaces)) + + formatted, err := GenerateTypes(s, "types") + require.NoError(t, err) + // os.WriteFile("./generated.txt", formatted, 0644) + + expTypes, err := os.ReadFile("./generated.txt") + require.NoError(t, err) + require.Equal(t, string(expTypes), string(formatted)) + + host, err := GenerateHost(s, packageName, "v0.1.0") + require.NoError(t, err) + // os.WriteFile("./host.txt", host, 0644) + expHost, err := os.ReadFile("./host.txt") + require.NoError(t, err) + require.Equal(t, string(expHost), string(host)) + + guest, err := GenerateGuest(s, packageName, "v0.1.0") + require.NoError(t, err) + // os.WriteFile("./guest.txt", guest, 0644) + expGuest, err := os.ReadFile("./guest.txt") + require.NoError(t, err) + require.Equal(t, string(expGuest), string(guest)) + + mod, err := GenerateModfile(packageName) + require.NoError(t, err) + // os.WriteFile("./modfile.txt", mod, 0644) + expMod, err := os.ReadFile("./modfile.txt") + require.NoError(t, err) + require.Equal(t, string(expMod), string(mod)) + +} diff --git a/extension/generator/golang/guest.txt b/extension/generator/golang/guest.txt new file mode 100644 index 00000000..6c869373 --- /dev/null +++ b/extension/generator/golang/guest.txt @@ -0,0 +1,102 @@ +// Code generated by scale-extension v0.1.0, DO NOT EDIT. +// schema: HttpFetch:alpha +// output: extfetch + +package extfetch + +import ( + "github.com/loopholelabs/polyglot" + "unsafe" +) + +var ( + writeBuffer = polyglot.NewBuffer() + readBuffer []byte +) + +//export ext_HttpFetch_Resize +//go:linkname ext_HttpFetch_Resize +func ext_HttpFetch_Resize(size uint32) uint32 { + readBuffer = make([]byte, size) + //if uint32(cap(readBuffer)) < size { + // readBuffer = append(make([]byte, 0, uint32(len(readBuffer))+size), readBuffer...) + //} + //readBuffer = readBuffer[:size] + return uint32(uintptr(unsafe.Pointer(&readBuffer[0]))) +} + +// Define any interfaces we need here... +// Also define structs we can use to hold instanceId + +// Define concrete types with a hidden instanceId + +type _HttpConnector struct { + instanceId uint64 +} + +func (d *_HttpConnector) Fetch(params *ConnectionDetails) (HttpResponse, error) { + + // First we take the params, serialize them. + writeBuffer.Reset() + params.Encode(writeBuffer) + underlying := writeBuffer.Bytes() + ptr := &underlying[0] + unsafePtr := uintptr(unsafe.Pointer(ptr)) + off := uint32(unsafePtr) + l := uint32(writeBuffer.Len()) + + // Now make the call to the host. + ext_HttpFetch_HttpConnector_Fetch(d.instanceId, off, l) + // IF the return type is a model, we should read the data from the read buffer. + + ret := &HttpResponse{} + r, err := DecodeHttpResponse(ret, readBuffer) + return *r, err + +} + +//export ext_HttpFetch_HttpConnector_Fetch +//go:linkname ext_HttpFetch_HttpConnector_Fetch +func ext_HttpFetch_HttpConnector_Fetch(instance uint64, offset uint32, length uint32) uint64 + +// Define any global functions here... + +//export ext_HttpFetch_New +//go:linkname ext_HttpFetch_New +func ext_HttpFetch_New(instance uint64, offset uint32, length uint32) uint64 + +func New(params *HttpConfig) (HttpConnector, error) { + // First we take the params, serialize them. + writeBuffer.Reset() + params.Encode(writeBuffer) + underlying := writeBuffer.Bytes() + ptr := &underlying[0] + unsafePtr := uintptr(unsafe.Pointer(ptr)) + off := uint32(unsafePtr) + l := uint32(writeBuffer.Len()) + + // Now make the call to the host. + v := ext_HttpFetch_New(0, off, l) + // IF the return type is an interface return ifc, which contains hidden instanceId. + + // TODO: Handle error from host. In this case there'll be an error in the readBuffer + + ret := &_HttpConnector{ + instanceId: v, + } + + return ret, nil + +} + +// Error serializes an error into the global writeBuffer and returns a pointer to the buffer and its size +// +// Users should not use this method. +func Error(err error) (uint32, uint32) { + writeBuffer.Reset() + polyglot.Encoder(writeBuffer).Error(err) + underlying := writeBuffer.Bytes() + ptr := &underlying[0] + unsafePtr := uintptr(unsafe.Pointer(ptr)) + return uint32(unsafePtr), uint32(writeBuffer.Len()) +} diff --git a/extension/generator/golang/host.txt b/extension/generator/golang/host.txt new file mode 100644 index 00000000..b2b45b17 --- /dev/null +++ b/extension/generator/golang/host.txt @@ -0,0 +1,142 @@ +// Code generated by scale-extension v0.1.0, DO NOT EDIT. +// schema: HttpFetch:alpha +// output: extfetch + +package extfetch + +import ( + "errors" + "github.com/loopholelabs/polyglot" + "sync" + "sync/atomic" + + extension "github.com/loopholelabs/scale-extension-interfaces" +) + +const identifier = "HttpFetch:alpha" + +// Write an error to the scale function guest buffer. +func hostError(mem extension.ModuleMemory, resize extension.Resizer, err error) { + b := polyglot.NewBuffer() + polyglot.Encoder(b).Error(err) + + writeBuffer, err := resize("ext_HttpFetch_Resize", uint64(b.Len())) + + if err != nil { + panic(err) + } + + if !mem.Write(uint32(writeBuffer), b.Bytes()) { + panic(err) + } +} + +type hostExt struct { + functions map[string]extension.InstallableFunc + host *HttpFetchHost +} + +func (he *hostExt) Init() map[string]extension.InstallableFunc { + return he.functions +} + +func (he *hostExt) Reset() { + // Reset any instances that have been created. + + he.host.instances_HttpConnector = make(map[uint64]HttpConnector) + +} + +func New(impl HttpFetchIfc) extension.Extension { + hostWrapper := &HttpFetchHost{impl: impl} + + fns := make(map[string]extension.InstallableFunc) + + // Add global functions to the runtime + + fns["ext_HttpFetch_New"] = hostWrapper.host_ext_HttpFetch_New + + hostWrapper.instances_HttpConnector = make(map[uint64]HttpConnector) + + fns["ext_HttpFetch_HttpConnector_Fetch"] = hostWrapper.host_ext_HttpFetch_HttpConnector_Fetch + + return &hostExt{ + functions: fns, + host: hostWrapper, + } +} + +type HttpFetchHost struct { + impl HttpFetchIfc + + gid_HttpConnector uint64 + instancesLock_HttpConnector sync.Mutex + instances_HttpConnector map[uint64]HttpConnector +} + +// Global functions + +func (h *HttpFetchHost) host_ext_HttpFetch_New(mem extension.ModuleMemory, resize extension.Resizer, params []uint64) { + ptr := uint32(params[0]) + length := uint32(params[1]) + data, _ := mem.Read(ptr, length) + + cd := &HttpConfig{} + cd, err := DecodeHttpConfig(cd, data) + if err != nil { + hostError(mem, resize, err) + } + + // Call the implementation + r, err := h.impl.New(cd) + if err != nil { + hostError(mem, resize, err) + } + + id := atomic.AddUint64(&h.gid_HttpConnector, 1) + h.instancesLock_HttpConnector.Lock() + h.instances_HttpConnector[id] = r + h.instancesLock_HttpConnector.Unlock() + + // Return the ID + params[0] = id + +} + +func (h *HttpFetchHost) host_ext_HttpFetch_HttpConnector_Fetch(mem extension.ModuleMemory, resize extension.Resizer, params []uint64) { + h.instancesLock_HttpConnector.Lock() + r, ok := h.instances_HttpConnector[params[0]] + h.instancesLock_HttpConnector.Unlock() + if !ok { + hostError(mem, resize, errors.New("Instance ID not found!")) + } + + ptr := uint32(params[1]) + length := uint32(params[2]) + data, _ := mem.Read(ptr, length) + + cd := &ConnectionDetails{} + cd, err := DecodeConnectionDetails(cd, data) + if err != nil { + hostError(mem, resize, err) + } + + resp, err := r.Fetch(cd) + if err != nil { + hostError(mem, resize, err) + } + + b := polyglot.NewBuffer() + resp.Encode(b) + + writeBuffer, err := resize("ext_HttpFetch_Resize", uint64(b.Len())) + + if err != nil { + hostError(mem, resize, err) + } + + if !mem.Write(uint32(writeBuffer), b.Bytes()) { + hostError(mem, resize, err) + } + +} diff --git a/extension/generator/golang/interfaces.txt b/extension/generator/golang/interfaces.txt new file mode 100644 index 00000000..5d37b4d8 --- /dev/null +++ b/extension/generator/golang/interfaces.txt @@ -0,0 +1,15 @@ +// Code generated by scale-extension v0.1.0, DO NOT EDIT. +// schema: HttpFetch:alpha +// output: extfetch + +package extfetch + +// Interface to the extension impl. This is what the implementor should create + +type HttpFetchIfc interface { + New(params *HttpConfig) (HttpConnector, error) +} + +type HttpConnector interface { + Fetch(*ConnectionDetails) (HttpResponse, error) +} diff --git a/extension/generator/golang/modfile.txt b/extension/generator/golang/modfile.txt new file mode 100644 index 00000000..e31c6312 --- /dev/null +++ b/extension/generator/golang/modfile.txt @@ -0,0 +1,8 @@ +module extfetch + +go 1.20 + +require ( + github.com/loopholelabs/polyglot v1.1.3 + github.com/loopholelabs/scale-extension-interfaces v0.1.0 +) \ No newline at end of file diff --git a/extension/generator/golang/templates/arrays.go.templ b/extension/generator/golang/templates/arrays.go.templ new file mode 100644 index 00000000..0d256b1d --- /dev/null +++ b/extension/generator/golang/templates/arrays.go.templ @@ -0,0 +1,116 @@ +{{ define "go_arrays_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + {{ LowerFirst .Name }} []{{ Primitive $type }} + {{- else }} + {{ .Name }} []{{ Primitive $type }} + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_arrays_new_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + {{ LowerFirst .Name }}: make([]{{ Primitive $type }}, 0, {{ .InitialSize }}), + {{- else }} + {{ .Name }}: make([]{{ Primitive $type }}, 0, {{ .InitialSize }}), + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_arrays_encode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + e.Slice(uint32(len(x.{{ LowerFirst .Name }})), {{ PolyglotPrimitive $type }}) + for _, a := range x.{{ LowerFirst .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(a) + } + {{- else }} + e.Slice(uint32(len(x.{{ .Name }})), {{ PolyglotPrimitive $type }}) + for _, a := range x.{{ .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(a) + } + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_arrays_decode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + sliceSize{{ LowerFirst .Name }}, err := d.Slice({{ PolyglotPrimitive .Type }}) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ LowerFirst .Name }})) != sliceSize{{ LowerFirst .Name }} { + x.{{ LowerFirst .Name }} = make([]{{ Primitive $type }}, sliceSize{{ LowerFirst .Name }}) + } + + for i := uint32(0); i < sliceSize{{ LowerFirst .Name }}; i++ { + x.{{ LowerFirst .Name }}[i], err = d.{{ PolyglotPrimitiveDecode $type }}() + if err != nil { + return nil, err + } + } + {{- else }} + sliceSize{{ .Name }}, err := d.Slice({{ PolyglotPrimitive $type }}) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ .Name }})) != sliceSize{{ .Name }} { + x.{{ .Name }} = make([]{{ Primitive $type }}, sliceSize{{ .Name }}) + } + + for i := uint32(0); i < sliceSize{{ .Name }}; i++ { + x.{{ .Name }}[i], err = d.{{ PolyglotPrimitiveDecode $type }}() + if err != nil { + return nil, err + } + } + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_bytesarrays_decode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + sliceSize{{ LowerFirst .Name }}, err := d.Slice({{ PolyglotPrimitive .Type }}) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ LowerFirst .Name }})) != sliceSize{{ LowerFirst .Name }} { + x.{{ LowerFirst .Name }} = make([]{{ Primitive $type }}, sliceSize{{ LowerFirst .Name }}) + } + + for i := uint32(0); i < sliceSize{{ LowerFirst .Name }}; i++ { + x.{{ LowerFirst .Name }}[i], err = d.{{ PolyglotPrimitiveDecode $type }}(nil) + if err != nil { + return nil, err + } + } + {{- else }} + sliceSize{{ .Name }}, err := d.Slice({{ PolyglotPrimitive $type }}) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ .Name }})) != sliceSize{{ .Name }} { + x.{{ .Name }} = make([]{{ Primitive $type }}, sliceSize{{ .Name }}) + } + + for i := uint32(0); i < sliceSize{{ .Name }}; i++ { + x.{{ .Name }}[i], err = d.{{ PolyglotPrimitiveDecode $type }}(nil) + if err != nil { + return nil, err + } + } + {{- end -}} + {{ end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/golang/templates/enumarrays.go.templ b/extension/generator/golang/templates/enumarrays.go.templ new file mode 100644 index 00000000..a9e984b9 --- /dev/null +++ b/extension/generator/golang/templates/enumarrays.go.templ @@ -0,0 +1,89 @@ +{{ define "go_enumarrays_definition" }} + {{ $current_model := . }} + {{- range .EnumArrays }} + type {{ .Name }} uint32 + {{ $current_enum := . }} + const ( + {{- range $index, $value := .Values }} + {{ $current_enum.Name }}{{ $value }} {{ .Reference }} = {{ $index }} + {{ end }} + ) + {{ end }} +{{ end }} + +{{ define "go_enumarrays_struct_reference" }} + {{ $current_model := . }} + {{- range .EnumArrays }} + {{- if (Deref .Accessor) }} + {{ LowerFirst .Name }} []{{ .Reference }} + {{- else }} + {{ .Name }} []{{ .Reference }} + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_enumarrays_new_struct_reference" }} + {{ $current_model := . }} + {{- range .EnumArrays }} + {{- if .Accessor }} + {{ LowerFirst .Name }}: make([]{{ .Reference }}, 0, {{ .InitialSize }}), + {{- else }} + {{ .Name }}: make([]{{ .Reference }}, 0, {{ .InitialSize }}), + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_enumarrays_encode" }} + {{- range .EnumArrays }} + {{- if (Deref .Accessor) }} + e.Slice(uint32(len(x.{{ LowerFirst .Name }})), polyglot.Uint32Kind) + for _, a := range x.{{ LowerFirst .Name }} { + e.Uint32(uint32(a)) + } + {{- else }} + e.Slice(uint32(len(x.{{ .Name }})), polyglot.Uint32Kind) + for _, a := range x.{{ .Name }} { + e.Uint32(uint32(a)) + } + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_enumarrays_decode" }} + {{ $current_model := . }} + {{- range .EnumArrays }} + {{- if (Deref .Accessor) }} + sliceSize{{ LowerFirst .Name }}, err := d.Slice(polyglot.Uint32Kind) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ LowerFirst .Name }})) != sliceSize{{ LowerFirst .Name }} { + x.{{ LowerFirst .Name }} = make([]{{ .Reference }}, sliceSize{{ LowerFirst .Name }}) + } + + val, err := decode{{ .Reference }}(d) + if err != nil { + return nil, err + } + x.{{ LowerFirst .Name }}[i] = val + {{- else }} + sliceSize{{ .Name }}, err := d.Slice(polyglot.Uint32Kind) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ .Name }})) != sliceSize{{ .Name }} { + x.{{ .Name }} = make([]{{ .Reference }}, sliceSize{{ .Name }}) + } + + for i := uint32(0); i < sliceSize{{ .Name }}; i++ { + val, err := decode{{ .Reference }}(d) + if err != nil { + return nil, err + } + x.{{ .Name }}[i] = val + } + {{- end -}} + {{ end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/golang/templates/enummaps.go.templ b/extension/generator/golang/templates/enummaps.go.templ new file mode 100644 index 00000000..d6a0ba9c --- /dev/null +++ b/extension/generator/golang/templates/enummaps.go.templ @@ -0,0 +1,205 @@ +{{ define "go_enummaps_definition" }} + {{ $current_model := . }} + {{- range .EnumMaps }} + type {{ .Name }} uint32 + {{ $current_enum := . }} + const ( + {{- range $index, $value := .Values }} + {{ $current_enum.Name }}{{ $value }} {{ $current_enum.Name }} = {{ $index }} + {{ end }} + ) + {{ end }} +{{ end }} + +{{ define "go_enummaps_struct_reference" }} + {{ $current_model := . }} + {{- range .EnumMaps }} + {{- if and (Deref .Accessor) (IsPrimitive .Value) }} + {{ LowerFirst .Name }} map[{{ .Reference }}]{{ Primitive .Value }} + {{- end }} + + {{- if and (Deref .Accessor) (not (IsPrimitive .Value)) }} + {{ LowerFirst .Name }} map[{{ .Reference }}]{{ .Value }} + {{- end }} + + {{- if and (not (Deref .Accessor)) (IsPrimitive .Value) }} + {{ .Name }} map[{{ .Reference }}]{{ Primitive .Value }} + {{- end }} + + {{- if and (not (Deref .Accessor)) (not (IsPrimitive .Value)) }} + {{ .Name }} map[{{ .Reference }}]{{ .Value }} + {{- end }} + {{ end }} +{{ end }} + +{{ define "go_enummaps_new_struct_reference" }} + {{ $current_model := . }} + {{- range .EnumMaps }} + {{- if and (Deref .Accessor) (IsPrimitive .Value) }} + {{ LowerFirst .Name }}: make(map[{{ .Reference }}]{{ Primitive .Value }}), + {{- end }} + + {{- if and (Deref .Accessor) (not (IsPrimitive .Value)) }} + {{ LowerFirst .Name }}: make(map[{{ .Reference }}]{{ .Value }}), + {{- end }} + + {{- if and (not (Deref .Accessor)) (IsPrimitive .Value) }} + {{ .Name }}: make(map[{{ .Reference }}]{{ Primitive .Value }}), + {{- end }} + + {{- if and (not (Deref .Accessor)) (not (IsPrimitive .Value)) }} + {{ .Name }}: make(map[{{ .Reference }}]{{ .Value }}), + {{- end }} + {{ end }} +{{ end }} + +{{ define "go_enummaps_encode" }} + {{- range .EnumMaps }} + {{- if and (Deref .Accessor) (IsPrimitive .Value) }} + e.Map(uint32(len(x.{{ LowerFirst .Name }})), polyglot.Uint32Kind, {{ PolyglotPrimitive .Value }}) + for k, v := range x.{{ LowerFirst .Name }} { + e.Uint32(uint32(k)) + e.{{ PolyglotPrimitiveEncode .Value }}(v) + } + {{- end }} + + {{- if and (Deref .Accessor) (not (IsPrimitive .Value)) }} + e.Map(uint32(len(x.{{ LowerFirst .Name }})), polyglot.Uint32Kind, polyglot.AnyKind) + for k, v := range x.{{ LowerFirst .Name }} { + e.Uint32(uint32(k)) + v.Encode(e) + } + {{- end }} + + {{- if and (not (Deref .Accessor)) (IsPrimitive .Value) }} + e.Map(uint32(len(x.{{ .Name }})), polyglot.Uint32Kind, {{ PolyglotPrimitive .Value }}) + for k, v := range x.{{ .Name }} { + e.Uint32(uint32(k)) + e.{{ PolyglotPrimitiveEncode .Value }}(v) + } + {{- end }} + + {{- if and (not (Deref .Accessor)) (not (IsPrimitive .Value)) }} + e.Map(uint32(len(x.{{ .Name }})), polyglot.Uint32Kind, polyglot.AnyKind) + for k, v := range x.{{ .Name }} { + e.Uint32(uint32(k)) + v.Encode(b) + } + {{- end }} + {{ end }} +{{ end }} + +{{ define "go_enummaps_decode" }} + {{ $current_model := . }} + {{- range .EnumMaps }} + {{- if and (Deref .Accessor) (IsPrimitive .Value) }} + mapSize{{ LowerFirst .Name }}, err := d.Map(polyglot.Uint32Kind, {{ PolyglotPrimitive .Value }}) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ LowerFirst .Name }})) != mapSize{{ LowerFirst .Name }} { + x.{{ LowerFirst .Name }} = make(map[{{ .Reference }}]{{ Primitive .Value }}, mapSize{{ LowerFirst .Name }}) + } + + for i := uint32(0); i < mapSize{{ LowerFirst .Name }}; i++ { + var k {{ .Reference }} + enumValue{{ LowerFirst .Name }}, err := d.Uint32() + if err != nil { + return nil, err + } + {{ $current_enum := . }} + switch {{ .Reference }}(enumValue{{ LowerFirst .Name }}) { + {{- range $index, $value := .Values }} + case {{ .Reference }}{{ $value }}: + k = {{ .Reference }}{{ $value }} + {{- end }} + default: + return nil, InvalidEnum + } + x.{{ LowerFirst .Name }}[k], err = d.{{ PolyglotPrimitiveDecode .Value }}() + if err != nil { + return nil, err + } + } + {{- end }} + + {{- if and (Deref .Accessor) (not (IsPrimitive .Value)) }} + mapSize{{ LowerFirst .Name }}, err := d.Map(polyglot.Uint32Kind, polyglot.AnyKind) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ LowerFirst .Name }})) != mapSize{{ LowerFirst .Name }} { + x.{{ LowerFirst .Name }} = make(map[{{ .Reference }}]{{ .Value }}, mapSize{{ LowerFirst .Name }}) + } + + for i := uint32(0); i < mapSize{{ LowerFirst .Name }}; i++ { + var k {{ .Reference }} + enumValue{{ LowerFirst .Name }}, err := d.Uint32() + if err != nil { + return nil, err + } + {{ $current_enum := . }} + switch {{ .Reference }}(enumValue{{ LowerFirst .Name }}) { + {{- range $index, $value := .Values }} + case {{ .Reference }}{{ $value }}: + k = {{ .Reference }}{{ $value }} + {{- end }} + default: + return nil, InvalidEnum + } + v, err := _decode{{ .Value }}(nil, d) + if err != nil { + return nil, err + } + x.{{ LowerFirst .Name }}[k] = *v + } + {{- end }} + + {{- if and (not (Deref .Accessor)) (IsPrimitive .Value) }} + mapSize{{ .Name }}, err := d.Map(polyglot.Uint32Kind, {{ PolyglotPrimitive .Value }}) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ .Name }})) != mapSize{{ .Name }} { + x.{{ .Name }} = make(map[{{ .Reference }}]{{ Primitive .Value }}, mapSize{{ .Name }}) + } + + for i := uint32(0); i < mapSize{{ .Name }}; i++ { + k, err := decode{{ .Reference }}(d) + if err != nil { + return nil, err + } + x.{{ .Name }}[k], err = d.{{ PolyglotPrimitiveDecode .Value }}() + if err != nil { + return nil, err + } + } + {{- end }} + + {{- if and (not (Deref .Accessor)) (not (IsPrimitive .Value)) }} + mapSize{{ .Name }}, err := d.Map(polyglot.Uint32Kind, polyglot.AnyKind) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ .Name }})) != mapSize{{ .Name }} { + x.{{ .Name }} = make(map[{{ .Reference }}]{{ .Value }}, mapSize{{ .Name }}) + } + + for i := uint32(0); i < mapSize{{ .Name }}; i++ { + k, err := decode{{ .Reference }}(d) + if err != nil { + return nil, err + } + v, err := _decode{{ .Value }}(nil, d) + if err != nil { + return nil, err + } + x.{{ .Name }}[k] = *v + } + {{- end }} + {{ end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/golang/templates/enums.go.templ b/extension/generator/golang/templates/enums.go.templ new file mode 100644 index 00000000..eb96205b --- /dev/null +++ b/extension/generator/golang/templates/enums.go.templ @@ -0,0 +1,90 @@ +{{ define "go_enums_definition" }} + type {{ .Name }} uint32 + {{ $current_enum := . }} + const ( + {{- range $index, $value := .Values }} + {{ $current_enum.Name }}{{ $value }} {{ $current_enum.Name }} = {{ $index }} + {{ end }} + ) + + func decode{{ .Name }}(d *polyglot.Decoder) ({{ .Name }}, error) { + enumValue, err := d.Uint32() + if err != nil { + return 0, err + } + switch {{ .Name }}(enumValue) { + {{- range $index, $value := .Values }} + case {{ $current_enum.Name }}{{ $value }}: + return {{ $current_enum.Name }}{{ $value }}, nil + {{- end }} + default: + return 0, InvalidEnum + } + } +{{ end }} + +{{ define "go_enums_struct_reference" }} + {{ $current_model := . }} + {{- range .Enums }} + {{- if (Deref .Accessor) }} + {{ LowerFirst .Name }} {{ .Reference }} + {{- else }} + {{ .Name }} {{ .Reference }} + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_enums_new_struct_reference" }} + {{ $current_model := . }} + {{- range .Enums }} + {{- if .Accessor }} + {{ LowerFirst .Name }}: {{ .Reference }}{{ .Default }}, + {{- else }} + {{ .Name }}: {{ .Reference }}{{ .Default }}, + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_enums_encode" }} + {{- range .Enums }} + {{- if (Deref .Accessor) }} + e.Uint32(uint32(x.{{ LowerFirst .Name }})) + {{- else }} + e.Uint32(uint32(x.{{ .Name }})) + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_enums_decode" }} + {{- range .Model.Enums }} + {{- if (Deref .Accessor) }} + result, err := decode{{ .Reference }}(d) + if err != nil { + return nil, err + } + x.{{ LowerFirst .Name }} = result + {{- else }} + result, err := decode{{ .Reference }}(d) + if err != nil { + return nil, err + } + x.{{ .Name }} = result + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_enums_accessor" }} + {{ $current_model := . }} + {{- range .Enums }} + {{- if .Accessor }} + func (x *{{ $current_model.Name }}) Get{{ .Name }}() ({{ .Reference }}, error) { + return x.{{ LowerFirst .Name }}, nil + } + + func (x *{{ $current_model.Name }}) Set{{ .Name }}(v {{ .Reference }}) error { + x.{{ LowerFirst .Name }} = v + return nil + } + {{- end -}} + {{ end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/golang/templates/guest.go.templ b/extension/generator/golang/templates/guest.go.templ new file mode 100644 index 00000000..9e234fb1 --- /dev/null +++ b/extension/generator/golang/templates/guest.go.templ @@ -0,0 +1,138 @@ +// Code generated by scale-extension {{ .version }}, DO NOT EDIT. +// schema: {{ .schema.Name }}:{{ .schema.Tag }} +// output: {{ .package }} + +package {{ .package }} + +import ( + "github.com/loopholelabs/polyglot" + "unsafe" +) + +var ( + writeBuffer = polyglot.NewBuffer() + readBuffer []byte +) + +//export ext_{{ .schema.Name }}_Resize +//go:linkname ext_{{ .schema.Name }}_Resize +func ext_{{ .schema.Name }}_Resize(size uint32) uint32 { + readBuffer = make([]byte, size) + //if uint32(cap(readBuffer)) < size { + // readBuffer = append(make([]byte, 0, uint32(len(readBuffer))+size), readBuffer...) + //} + //readBuffer = readBuffer[:size] + return uint32(uintptr(unsafe.Pointer(&readBuffer[0]))) +} + +{{ $schema := .schema }} + +// Define any interfaces we need here... +// Also define structs we can use to hold instanceId + +{{ range $ifc := .schema.Interfaces }} + +// Define concrete types with a hidden instanceId + +type _{{ $ifc.Name }} struct { + instanceId uint64 +} + +{{ range $fn := $ifc.Functions }} +func (d *_{{ $ifc.Name }}) {{ $fn.Name }}(params *{{ $fn.Params }}) ({{ $fn.Return }}, error) { + + // First we take the params, serialize them. + writeBuffer.Reset() + params.Encode(writeBuffer) + underlying := writeBuffer.Bytes() + ptr := &underlying[0] + unsafePtr := uintptr(unsafe.Pointer(ptr)) + off := uint32(unsafePtr) + l := uint32(writeBuffer.Len()) + + // Now make the call to the host. + + {{- if (IsInterface $schema $fn.Return) }} + v := ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }}(d.instanceId, off, l) + // IF the return type is an interface return ifc, which contains hidden instanceId. + + // TODO: Handle error from host. In this case there'll be an error in the readBuffer + + ret := &_{{ $fn.Return }}{ + instanceId: v, + } + + return ret, nil + {{ else }} + ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }}(d.instanceId, off, l) + // IF the return type is a model, we should read the data from the read buffer. + + ret := &{{ $fn.Return }}{} + r, err := Decode{{ $fn.Return }}(ret, readBuffer) + return *r, err + {{ end }} +} + +//export ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }} +//go:linkname ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }} +func ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }}(instance uint64, offset uint32, length uint32) uint64 + +{{ end }} + +{{ end }} + +// Define any global functions here... + +{{ range $fn := .schema.Functions }} + +//export ext_{{ $schema.Name }}_{{ $fn.Name }} +//go:linkname ext_{{ $schema.Name }}_{{ $fn.Name }} +func ext_{{ $schema.Name }}_{{ $fn.Name }}(instance uint64, offset uint32, length uint32) uint64 + +func {{ $fn.Name }}(params *{{ $fn.Params }}) ({{ $fn.Return }}, error) { + // First we take the params, serialize them. + writeBuffer.Reset() + params.Encode(writeBuffer) + underlying := writeBuffer.Bytes() + ptr := &underlying[0] + unsafePtr := uintptr(unsafe.Pointer(ptr)) + off := uint32(unsafePtr) + l := uint32(writeBuffer.Len()) + + // Now make the call to the host. + + {{- if (IsInterface $schema $fn.Return) }} + v := ext_{{ $schema.Name }}_{{ $fn.Name }}(0, off, l) + // IF the return type is an interface return ifc, which contains hidden instanceId. + + // TODO: Handle error from host. In this case there'll be an error in the readBuffer + + ret := &_{{ $fn.Return }}{ + instanceId: v, + } + + return ret, nil + {{ else }} + ext_{{ $schema.Name }}_{{ $fn.Name }}(0, off, l) + // IF the return type is a model, we should read the data from the read buffer. + + ret := &{{ $fn.Return }}{} + r, err := Decode{{ $fn.Return }}(ret, readBuffer) + return *r, err + {{ end }} + +} + +{{ end }} + +// Error serializes an error into the global writeBuffer and returns a pointer to the buffer and its size +// +// Users should not use this method. +func Error(err error) (uint32, uint32) { + writeBuffer.Reset() + polyglot.Encoder(writeBuffer).Error(err) + underlying := writeBuffer.Bytes() + ptr := &underlying[0] + unsafePtr := uintptr(unsafe.Pointer(ptr)) + return uint32(unsafePtr), uint32(writeBuffer.Len()) +} diff --git a/extension/generator/golang/templates/host.go.templ b/extension/generator/golang/templates/host.go.templ new file mode 100644 index 00000000..91209aa2 --- /dev/null +++ b/extension/generator/golang/templates/host.go.templ @@ -0,0 +1,198 @@ +// Code generated by scale-extension {{ .version }}, DO NOT EDIT. +// schema: {{ .schema.Name }}:{{ .schema.Tag }} +// output: {{ .package }} + +package {{ .package }} + +import ( + "errors" + "sync/atomic" + "sync" + "github.com/loopholelabs/polyglot" + + extension "github.com/loopholelabs/scale-extension-interfaces" +) + +const identifier = "{{ .schema.Name }}:{{ .schema.Tag }}" + +// Write an error to the scale function guest buffer. +func hostError(mem extension.ModuleMemory, resize extension.Resizer, err error) { + b := polyglot.NewBuffer() + polyglot.Encoder(b).Error(err) + + writeBuffer, err := resize("ext_HttpFetch_Resize", uint64(b.Len())) + + if err != nil { + panic(err) + } + + if !mem.Write(uint32(writeBuffer), b.Bytes()) { + panic(err) + } +} + +{{ $schema := .schema }} + +type hostExt struct { + functions map[string]extension.InstallableFunc + host *{{ .schema.Name }}Host +} + +func (he *hostExt) Init() map[string]extension.InstallableFunc { + return he.functions +} + +func (he *hostExt) Reset() { + // Reset any instances that have been created. + {{ range $ifc := .schema.Interfaces }} + he.host.instances_{{ $ifc.Name }} = make(map[uint64]{{ $ifc.Name }}) + {{ end }} +} + +func New(impl {{ .schema.Name }}Ifc) extension.Extension { + hostWrapper := &{{ .schema.Name }}Host{ impl: impl } + + fns := make(map[string]extension.InstallableFunc) + +// Add global functions to the runtime +{{ range $fn := .schema.Functions }} + fns["ext_{{ $schema.Name }}_{{ $fn.Name }}"] = hostWrapper.host_ext_{{ $schema.Name }}_{{ $fn.Name }} +{{ end }} + +{{ range $ifc := .schema.Interfaces }} + hostWrapper.instances_{{ $ifc.Name }} = make(map[uint64]{{ $ifc.Name }}) + + {{ range $fn := $ifc.Functions }} + + fns["ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }}"] = hostWrapper.host_ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }} + + {{ end }} +{{ end }} + + return &hostExt{ + functions: fns, + host: hostWrapper, + } +} + +type {{ .schema.Name }}Host struct { + impl {{ .schema.Name }}Ifc +{{ range $ifc := .schema.Interfaces }} + + gid_{{ $ifc.Name }} uint64 + instancesLock_{{ $ifc.Name }} sync.Mutex + instances_{{ $ifc.Name }} map[uint64]{{ $ifc.Name }} + +{{ end }} + +} + +// Global functions +{{ range $fn := .schema.Functions }} + +func (h *{{ $schema.Name }}Host) host_ext_{{ $schema.Name }}_{{ $fn.Name}}(mem extension.ModuleMemory, resize extension.Resizer, params []uint64) { + ptr := uint32(params[0]) + length := uint32(params[1]) + data, _ := mem.Read(ptr, length) + + cd := &{{ $fn.Params }}{} + cd, err := Decode{{ $fn.Params }}(cd, data) + if err != nil { + hostError(mem, resize, err) + } + + // Call the implementation + r, err := h.impl.{{ $fn.Name }}(cd) + if err!=nil { + hostError(mem, resize, err) + } + +{{- if (IsInterface $schema $fn.Return) }} + + id := atomic.AddUint64(&h.gid_{{ $fn.Return }}, 1) + h.instancesLock_{{ $fn.Return }}.Lock() + h.instances_{{ $fn.Return }}[id] = r + h.instancesLock_{{ $fn.Return }}.Unlock() + + // Return the ID + params[0] = id + +{{ else }} + + b := polyglot.NewBuffer() + r.Encode(b) + + writeBuffer, err := resize("ext_{{ $schema.Name }}_Resize", uint64(b.Len())) + + if err != nil { + hostError(mem, resize, err) + } + + if !mem.Write(uint32(writeBuffer), b.Bytes()) { + hostError(mem, resize, err) + } + +{{ end }} +} + +{{ end }} + + +{{ range $ifc := .schema.Interfaces }} + +{{ range $fn := $ifc.Functions }} + +func (h *{{ $schema.Name }}Host) host_ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }}(mem extension.ModuleMemory, resize extension.Resizer, params []uint64) { + h.instancesLock_{{ $ifc.Name }}.Lock() + r, ok := h.instances_{{ $ifc.Name }}[params[0]] + h.instancesLock_{{ $ifc.Name }}.Unlock() + if !ok { + hostError(mem, resize, errors.New("Instance ID not found!")) + } + + ptr := uint32(params[1]) + length := uint32(params[2]) + data, _ := mem.Read(ptr, length) + + cd := &{{ $fn.Params }}{} + cd, err := Decode{{ $fn.Params }}(cd, data) + if err != nil { + hostError(mem, resize, err) + } + + resp, err := r.{{ $fn.Name }}(cd) + if err != nil { + hostError(mem, resize, err) + } + + +{{- if (IsInterface $schema $fn.Return) }} + + id := atomic.AddUint64(&h.gid_{{ $fn.Return }}, 1) + h.instancesLock_{{ $fn.Return }}.Lock() + h.instances_{{ $fn.Return }}[id] = resp + h.instancesLock_{{ $fn.Return }}.Unlock() + + // Return the ID + params[0] = id + +{{ else }} + + b := polyglot.NewBuffer() + resp.Encode(b) + + writeBuffer, err := resize("ext_{{ $schema.Name }}_Resize", uint64(b.Len())) + + if err != nil { + hostError(mem, resize, err) + } + + if !mem.Write(uint32(writeBuffer), b.Bytes()) { + hostError(mem, resize, err) + } + +{{ end }} +} + + {{ end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/golang/templates/interfaces.go.templ b/extension/generator/golang/templates/interfaces.go.templ new file mode 100644 index 00000000..4da9d46f --- /dev/null +++ b/extension/generator/golang/templates/interfaces.go.templ @@ -0,0 +1,29 @@ +// Code generated by scale-extension {{ .version }}, DO NOT EDIT. +// schema: {{ .schema.Name }}:{{ .schema.Tag }} +// output: {{ .package }} + +package {{ .package }} + + +// Interface to the extension impl. This is what the implementor should create + +type {{ .schema.Name }}Ifc interface { +{{ range $fn := .schema.Functions }} + {{ $fn.Name }}(params *{{ $fn.Params }}) ({{ $fn.Return }}, error) +{{ end }} +} + + +{{ range $ifc := .schema.Interfaces }} + +type {{ $ifc.Name }} interface { + +{{ range $fn := $ifc.Functions }} + + {{ $fn.Name }}(*{{ $fn.Params }}) ({{ $fn.Return }}, error) + +{{ end }} + +} + +{{ end }} \ No newline at end of file diff --git a/extension/generator/golang/templates/maps.go.templ b/extension/generator/golang/templates/maps.go.templ new file mode 100644 index 00000000..bbfe4639 --- /dev/null +++ b/extension/generator/golang/templates/maps.go.templ @@ -0,0 +1,173 @@ +{{ define "go_maps_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if and (Deref .Accessor) (IsPrimitive .Value) }} + {{ LowerFirst .Name }} map[{{ Primitive $type }}]{{ Primitive .Value }} + {{- end }} + + {{- if and (Deref .Accessor) (not (IsPrimitive .Value)) }} + {{ LowerFirst .Name }} map[{{ Primitive $type }}]{{ .Value }} + {{- end }} + + {{- if and (not (Deref .Accessor)) (IsPrimitive .Value) }} + {{ .Name }} map[{{ Primitive $type }}]{{ Primitive .Value }} + {{- end }} + + {{- if and (not (Deref .Accessor)) (not (IsPrimitive .Value)) }} + {{ .Name }} map[{{ Primitive $type }}]{{ .Value }} + {{- end }} + {{ end }} +{{ end }} + +{{ define "go_maps_new_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if and (Deref .Accessor) (IsPrimitive .Value) }} + {{ LowerFirst .Name }}: make(map[{{ Primitive $type }}]{{ Primitive .Value }}), + {{- end }} + + {{- if and (Deref .Accessor) (not (IsPrimitive .Value)) }} + {{ LowerFirst .Name }}: make(map[{{ Primitive $type }}]{{ .Value }}), + {{- end }} + + {{- if and (not (Deref .Accessor)) (IsPrimitive .Value) }} + {{ .Name }}: make(map[{{ Primitive $type }}]{{ Primitive .Value }}), + {{- end }} + + {{- if and (not (Deref .Accessor)) (not (IsPrimitive .Value)) }} + {{ .Name }}: make(map[{{ Primitive $type }}]{{ .Value }}), + {{- end }} + {{ end }} +{{ end }} + +{{ define "go_maps_encode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if and (Deref .Accessor) (IsPrimitive .Value) }} + e.Map(uint32(len(x.{{ LowerFirst .Name }})), {{ PolyglotPrimitive $type }}, {{ PolyglotPrimitive .Value }}) + for k, v := range x.{{ LowerFirst .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(k) + e.{{ PolyglotPrimitiveEncode .Value }}(v) + } + {{- end }} + + {{- if and (Deref .Accessor) (not (IsPrimitive .Value)) }} + e.Map(uint32(len(x.{{ LowerFirst .Name }})), {{ PolyglotPrimitive $type }}, polyglot.AnyKind) + for k, v := range x.{{ LowerFirst .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(k) + v.Encode(b) + } + {{- end }} + + {{- if and (not (Deref .Accessor)) (IsPrimitive .Value) }} + e.Map(uint32(len(x.{{ .Name }})), {{ PolyglotPrimitive $type }}, {{ PolyglotPrimitive .Value }}) + for k, v := range x.{{ .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(k) + e.{{ PolyglotPrimitiveEncode .Value }}(v) + } + {{- end }} + + {{- if and (not (Deref .Accessor)) (not (IsPrimitive .Value)) }} + e.Map(uint32(len(x.{{ .Name }})), {{ PolyglotPrimitive $type }}, polyglot.AnyKind) + for k, v := range x.{{ .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(k) + v.Encode(b) + } + {{- end }} + {{ end }} +{{ end }} + +{{ define "go_maps_decode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if and (Deref .Accessor) (IsPrimitive .Value) }} + mapSize{{ LowerFirst .Name }}, err := d.Map({{ PolyglotPrimitive $type }}, {{ PolyglotPrimitive .Value }}) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ LowerFirst .Name }})) != mapSize{{ LowerFirst .Name }} { + x.{{ LowerFirst .Name }} = make(map[{{ Primitive $type }}]{{ Primitive .Value }}, mapSize{{ LowerFirst .Name }}) + } + + for i := uint32(0); i < mapSize{{ LowerFirst .Name }}; i++ { + k, err := d.{{ PolyglotPrimitiveDecode $type }}() + if err != nil { + return nil, err + } + x.{{ LowerFirst .Name }}[k], err = d.{{ PolyglotPrimitiveDecode .Value }}() + if err != nil { + return nil, err + } + } + {{- end }} + + {{- if and (Deref .Accessor) (not (IsPrimitive .Value)) }} + mapSize{{ LowerFirst .Name }}, err := d.Map({{ PolyglotPrimitive $type }}, polyglot.AnyKind) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ LowerFirst .Name }})) != mapSize{{ LowerFirst .Name }} { + x.{{ LowerFirst .Name }} = make(map[{{ Primitive $type }}]{{ .Value }}, mapSize{{ LowerFirst .Name }}) + } + + for i := uint32(0); i < mapSize{{ LowerFirst .Name }}; i++ { + k, err := d.{{ PolyglotPrimitiveDecode $type }}() + if err != nil { + return nil, err + } + v, err := _decode{{ .Value }}(nil, d) + if err != nil { + return nil, err + } + x.{{ LowerFirst .Name }}[k] = *v + } + {{- end }} + + {{- if and (not (Deref .Accessor)) (IsPrimitive .Value) }} + mapSize{{ .Name }}, err := d.Map({{ PolyglotPrimitive $type }}, {{ PolyglotPrimitive .Value }}) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ .Name }})) != mapSize{{ .Name }} { + x.{{ .Name }} = make(map[{{ Primitive $type }}]{{ Primitive .Value }}, mapSize{{ .Name }}) + } + + for i := uint32(0); i < mapSize{{ .Name }}; i++ { + k, err := d.{{ PolyglotPrimitiveDecode $type }}() + if err != nil { + return nil, err + } + x.{{ .Name }}[k], err = d.{{ PolyglotPrimitiveDecode .Value }}() + if err != nil { + return nil, err + } + } + {{- end }} + + {{- if and (not (Deref .Accessor)) (not (IsPrimitive .Value)) }} + mapSize{{ .Name }}, err := d.Map({{ PolyglotPrimitive $type }}, polyglot.AnyKind) + if err != nil { + return nil, err + } + + if uint32(len(x.{{ .Name }})) != mapSize{{ .Name }} { + x.{{ .Name }} = make(map[{{ Primitive $type }}]{{ .Value }}, mapSize{{ .Name }}) + } + + for i := uint32(0); i < mapSize{{ .Name }}; i++ { + k, err := d.{{ PolyglotPrimitiveDecode $type }}() + if err != nil { + return nil, err + } + v, err := _decode{{ .Value }}(nil, d) + if err != nil { + return nil, err + } + x.{{ .Name }}[k] = *v + } + {{- end }} + {{ end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/golang/templates/mod.go.templ b/extension/generator/golang/templates/mod.go.templ new file mode 100644 index 00000000..85c10b0b --- /dev/null +++ b/extension/generator/golang/templates/mod.go.templ @@ -0,0 +1,8 @@ +module {{ .package_import_path }} + +go 1.20 + +require ( + github.com/loopholelabs/polyglot {{ .polyglot_version }} + github.com/loopholelabs/scale-extension-interfaces {{ .scale_extension_interfaces_version }} +) \ No newline at end of file diff --git a/extension/generator/golang/templates/modelarrays.go.templ b/extension/generator/golang/templates/modelarrays.go.templ new file mode 100644 index 00000000..71b8e7a8 --- /dev/null +++ b/extension/generator/golang/templates/modelarrays.go.templ @@ -0,0 +1,87 @@ +{{ define "go_modelarrays_struct_reference" }} + {{- range .ModelArrays }} + {{- if .Accessor }} + {{ LowerFirst .Name }} []{{ .Reference }} + {{- else }} + {{ .Name }} []{{ .Reference }} + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_modelarrays_new_struct_reference" }} + {{- range .ModelArrays }} + {{- if .Accessor }} + {{ LowerFirst .Name }}: make([]{{ .Reference }}, {{ .InitialSize }}), + {{- else }} + {{ .Name }}: make([]{{ .Reference }}, 0, {{ .InitialSize }}), + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_modelarrays_encode" }} + {{- range .ModelArrays }} + {{- if .Accessor }} + e.Slice(uint32(len(x.{{ LowerFirst .Name }})), polyglot.AnyKind) + for _, a := range x.{{ LowerFirst .Name }} { + a.Encode(b) + } + {{- else }} + e.Slice(uint32(len(x.{{ .Name }})), polyglot.AnyKind) + for _, a := range x.{{ .Name }} { + a.Encode(b) + } + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_modelarrays_decode" }} + {{- range .ModelArrays }} + {{- if .Accessor }} + sliceSize{{ LowerFirst .Name }}, err := d.Slice(polyglot.AnyKind) + if err != nil { + return nil, err + } + if uint32(len(x.{{ LowerFirst .Name }})) != sliceSize{{ LowerFirst .Name }} { + x.{{ LowerFirst .Name }} = make([]{{ .Reference }}, sliceSize{{ LowerFirst .Name }}) + } + for i := uint32(0); i < sliceSize{{ LowerFirst .Name }}; i++ { + v, err := _decode{{ .Reference }}(nil, d) + if err != nil { + return nil, err + } + x.{{ LowerFirst .Name }}[i] = *v + } + {{- else }} + sliceSize{{ .Name }}, err := d.Slice(polyglot.AnyKind) + if err != nil { + return nil, err + } + if uint32(len(x.{{ .Name }})) != sliceSize{{ .Name }} { + x.{{ .Name }} = make([]{{ .Reference }}, sliceSize{{ .Name }}) + } + for i := uint32(0); i < sliceSize{{ .Name }}; i++ { + v, err := _decode{{ .Reference }}(nil, d) + if err != nil { + return nil, err + } + x.{{ .Name }}[i] = *v + } + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_modelarrays_accessor" }} + {{ $current_model := . }} + {{- range .ModelArrays }} + {{- if .Accessor }} + func (x *{{ $current_model.Name }}) Get{{ .Name }}() ([]{{ .Reference }}, error) { + return x.{{ LowerFirst .Name }}, nil + } + + func (x *{{ $current_model.Name }}) Set{{ .Name }}(v []{{ .Reference }}) error { + x.{{ LowerFirst .Name }} = v + return nil + } + {{- end -}} + {{ end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/golang/templates/models.go.templ b/extension/generator/golang/templates/models.go.templ new file mode 100644 index 00000000..e35e81a0 --- /dev/null +++ b/extension/generator/golang/templates/models.go.templ @@ -0,0 +1,61 @@ +{{ define "go_models_struct_reference" }} + {{- range .Models }} + {{- if .Accessor }} + {{ LowerFirst .Name }} *{{ .Reference }} + {{- else }} + {{ .Name }} *{{ .Reference }} + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_models_new_struct_reference" }} + {{- range .Models }} + {{- if .Accessor }} + {{ LowerFirst .Name }}: New{{ .Reference }}(), + {{- else }} + {{ .Name }}: New{{ .Reference }}(), + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_models_encode" }} + {{- range .Models }} + {{- if .Accessor }} + x.{{ LowerFirst .Name }}.Encode(b) + {{- else }} + x.{{ .Name }}.Encode(b) + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_models_decode" }} + {{- range .Models }} + {{- if .Accessor }} + x.{{ LowerFirst .Name }}, err = _decode{{ .Reference }}(nil, d) + if err != nil { + return nil, err + } + {{- else }} + x.{{ .Name }}, err = _decode{{ .Reference }}(nil, d) + if err != nil { + return nil, err + } + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_models_accessor" }} + {{ $current_model := . }} + {{- range .Models }} + {{- if .Accessor }} + func (x *{{ $current_model.Name }}) Get{{ .Name }}() (*{{ .Reference }}, error) { + return x.{{ LowerFirst .Name }}, nil + } + + func (x *{{ $current_model.Name }}) Set{{ .Name }}(v *{{ .Reference }}) error { + x.{{ LowerFirst .Name }} = v + return nil + } + {{- end -}} + {{ end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/golang/templates/primitives.go.templ b/extension/generator/golang/templates/primitives.go.templ new file mode 100644 index 00000000..ba8205a2 --- /dev/null +++ b/extension/generator/golang/templates/primitives.go.templ @@ -0,0 +1,126 @@ +{{ define "go_primitives_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + {{ LowerFirst .Name }} {{ Primitive $type }} + {{- else }} + {{ .Name }} {{ Primitive $type }} + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_primitives_new_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + {{ LowerFirst .Name }}: {{ .Default }}, + {{- else }} + {{ .Name }}: {{ .Default }}, + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_strings_new_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + {{ LowerFirst .Name }}: "{{ .Default }}", + {{- else }} + {{ .Name }}: "{{ .Default }}", + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_bytes_new_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + {{ LowerFirst .Name }}: make([]byte, 0, {{ .InitialSize }}), + {{- else }} + {{ .Name }}: make([]byte, 0, {{ .InitialSize }}), + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_primitives_encode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + e.{{ PolyglotPrimitiveEncode $type }}(x.{{ LowerFirst .Name }}) + {{- else }} + e.{{ PolyglotPrimitiveEncode $type }}(x.{{ .Name }}) + {{- end -}} + {{ end }} +{{ end}} + +{{ define "go_primitives_decode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + x.{{ LowerFirst .Name }}, err = d.{{ PolyglotPrimitiveDecode $type }}() + if err != nil { + return nil, err + } + {{- else }} + x.{{ .Name }}, err = d.{{ PolyglotPrimitiveDecode $type }}() + if err != nil { + return nil, err + } + {{- end -}} + {{ end }} +{{ end}} + +{{ define "go_bytes_decode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + x.{{ LowerFirst .Name }}, err = d.{{ PolyglotPrimitiveDecode $type }}(nil) + if err != nil { + return nil, err + } + {{- else }} + x.{{ .Name }}, err = d.{{ PolyglotPrimitiveDecode $type }}(nil) + if err != nil { + return nil, err + } + {{- end -}} + {{ end }} +{{ end}} + +{{ define "go_numbers_accessor" }} + {{ $type := .Type }} + {{ $model := .Model }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + func (x *{{ $model.Name }}) Get{{ .Name }}() ({{ $type }}, error) { + return x.{{ LowerFirst .Name }}, nil + } + + func (x *{{ $model.Name }}) Set{{ .Name }}(v {{ $type }}) error { + {{- template "go_numbers_limit_validator" .LimitValidator }} + x.{{ LowerFirst .Name }} = v + return nil + } + {{- end -}} + {{ end }} +{{ end }} + +{{ define "go_strings_accessor" }} + {{ $type := .Type }} + {{ $model := .Model }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + func (x *{{ $model.Name }}) Get{{ .Name }}() ({{ $type }}, error) { + return x.{{ LowerFirst .Name }}, nil + } + + func (x *{{ $model.Name }}) Set{{ .Name }}(v {{ $type }}) error { + {{- template "go_regex_validator" .RegexValidator }} + {{- template "go_length_validator" .LengthValidator }} + {{- template "go_case_modifier" .CaseModifier }} + x.{{ LowerFirst .Name }} = v + return nil + } + {{- end -}} + {{ end }} +{{ end }} diff --git a/extension/generator/golang/templates/templates.go b/extension/generator/golang/templates/templates.go new file mode 100644 index 00000000..fdfb581a --- /dev/null +++ b/extension/generator/golang/templates/templates.go @@ -0,0 +1,19 @@ +/* + Copyright 2023 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package templates + +import "embed" + +//go:embed * +var FS embed.FS diff --git a/extension/generator/golang/templates/types.go.templ b/extension/generator/golang/templates/types.go.templ new file mode 100644 index 00000000..27be5545 --- /dev/null +++ b/extension/generator/golang/templates/types.go.templ @@ -0,0 +1,235 @@ +// Code generated by scale-signature {{ .generator_version }}, DO NOT EDIT. +// output: {{ .package_name }} + +package {{ .package_name }} + +import ( + "github.com/loopholelabs/polyglot" + "errors" + {{ if or (.signature_schema.HasLengthValidator) (.signature_schema.HasRegexValidator) (.signature_schema.HasLimitValidator) }}"fmt"{{ end }} + {{ if .signature_schema.HasRegexValidator }}"regexp"{{ end }} + {{ if .signature_schema.HasCaseModifier }}"strings"{{ end }} +) + +var ( + NilDecode = errors.New("cannot decode into a nil root struct") + InvalidEnum = errors.New("invalid enum value") +) + +{{ range .signature_schema.Enums }} + {{ template "go_enums_definition" . }} +{{- end }} +{{- $allEnums := .signature_schema.Enums }} + +{{- range .signature_schema.Models -}} + {{- if .Description }} + // {{ .Name }}: {{ .Description }} + {{ end -}} + + type {{ .Name }} struct { + {{ template "go_models_struct_reference" . }} + {{ template "go_modelarrays_struct_reference" . }} + + {{ template "go_primitives_struct_reference" Params "Entries" .Strings "Type" "string" }} + {{ template "go_arrays_struct_reference" Params "Entries" .StringArrays "Type" "string" }} + {{ template "go_maps_struct_reference" Params "Entries" .StringMaps "Type" "string" }} + + {{ template "go_primitives_struct_reference" Params "Entries" .Int32s "Type" "int32" }} + {{ template "go_arrays_struct_reference" Params "Entries" .Int32Arrays "Type" "int32" }} + {{ template "go_maps_struct_reference" Params "Entries" .Int32Maps "Type" "int32" }} + + + {{ template "go_primitives_struct_reference" Params "Entries" .Int64s "Type" "int64" }} + {{ template "go_arrays_struct_reference" Params "Entries" .Int64Arrays "Type" "int64" }} + {{ template "go_maps_struct_reference" Params "Entries" .Int64Maps "Type" "int64" }} + + {{ template "go_primitives_struct_reference" Params "Entries" .Uint32s "Type" "uint32" }} + {{ template "go_arrays_struct_reference" Params "Entries" .Uint32Arrays "Type" "uint32" }} + {{ template "go_maps_struct_reference" Params "Entries" .Uint32Maps "Type" "uint32" }} + + {{ template "go_primitives_struct_reference" Params "Entries" .Uint64s "Type" "uint64" }} + {{ template "go_arrays_struct_reference" Params "Entries" .Uint64Arrays "Type" "uint64" }} + {{ template "go_maps_struct_reference" Params "Entries" .Uint64Maps "Type" "uint64" }} + + {{ template "go_primitives_struct_reference" Params "Entries" .Float32s "Type" "float32" }} + {{ template "go_arrays_struct_reference" Params "Entries" .Float32Arrays "Type" "float32" }} + + {{ template "go_primitives_struct_reference" Params "Entries" .Float64s "Type" "float64" }} + {{ template "go_arrays_struct_reference" Params "Entries" .Float64Arrays "Type" "float64" }} + + {{ template "go_enums_struct_reference" . }} + {{ template "go_enumarrays_struct_reference" . }} + {{ template "go_enummaps_struct_reference" . }} + + {{ template "go_primitives_struct_reference" Params "Entries" .Bytes "Type" "bytes" }} + {{ template "go_arrays_struct_reference" Params "Entries" .BytesArrays "Type" "bytes" }} + + {{ template "go_primitives_struct_reference" Params "Entries" .Bools "Type" "bool" }} + {{ template "go_arrays_struct_reference" Params "Entries" .BoolArrays "Type" "bool" }} + } + + func New{{ .Name }}() *{{ .Name }} { + return &{{ .Name }}{ + {{ template "go_models_new_struct_reference" . }} + {{ template "go_modelarrays_new_struct_reference" . }} + + {{ template "go_strings_new_struct_reference" Params "Entries" .Strings }} + {{ template "go_arrays_new_struct_reference" Params "Entries" .StringArrays "Type" "string" }} + {{ template "go_maps_new_struct_reference" Params "Entries" .StringMaps "Type" "string" }} + + {{ template "go_primitives_new_struct_reference" Params "Entries" .Int32s }} + {{ template "go_arrays_new_struct_reference" Params "Entries" .Int32Arrays "Type" "int32" }} + {{ template "go_maps_new_struct_reference" Params "Entries" .Int32Maps "Type" "int32" }} + + {{ template "go_primitives_new_struct_reference" Params "Entries" .Int64s }} + {{ template "go_arrays_new_struct_reference" Params "Entries" .Int64Arrays "Type" "int64" }} + {{ template "go_maps_new_struct_reference" Params "Entries" .Int64Maps "Type" "int64" }} + + {{ template "go_primitives_new_struct_reference" Params "Entries" .Uint32s }} + {{ template "go_arrays_new_struct_reference" Params "Entries" .Uint32Arrays "Type" "uint32" }} + {{ template "go_maps_new_struct_reference" Params "Entries" .Uint32Maps "Type" "uint32" }} + + {{ template "go_primitives_new_struct_reference" Params "Entries" .Uint64s }} + {{ template "go_arrays_new_struct_reference" Params "Entries" .Uint64Arrays "Type" "uint64" }} + {{ template "go_maps_new_struct_reference" Params "Entries" .Uint64Maps "Type" "uint64" }} + + {{ template "go_primitives_new_struct_reference" Params "Entries" .Float32s }} + {{ template "go_arrays_new_struct_reference" Params "Entries" .Float32Arrays "Type" "float32" }} + + {{ template "go_primitives_new_struct_reference" Params "Entries" .Float64s }} + {{ template "go_arrays_new_struct_reference" Params "Entries" .Float64Arrays "Type" "float64" }} + + {{ template "go_enums_new_struct_reference" . }} + {{ template "go_enumarrays_new_struct_reference" . }} + {{ template "go_enummaps_new_struct_reference" . }} + + {{ template "go_bytes_new_struct_reference" Params "Entries" .Bytes "Type" "bytes" }} + {{ template "go_arrays_new_struct_reference" Params "Entries" .BytesArrays "Type" "bytes" }} + + {{ template "go_primitives_new_struct_reference" Params "Entries" .Bools }} + {{ template "go_arrays_new_struct_reference" Params "Entries" .BoolArrays "Type" "bool" }} + } + } + + func (x *{{ .Name }}) Encode(b *polyglot.Buffer) { + e := polyglot.Encoder(b) + if x == nil { + e.Nil() + } else { + {{ template "go_models_encode" . }} + {{ template "go_modelarrays_encode" . }} + + {{ template "go_primitives_encode" Params "Entries" .Strings "Type" "string" }} + {{ template "go_arrays_encode" Params "Entries" .StringArrays "Type" "string" }} + {{ template "go_maps_encode" Params "Entries" .StringMaps "Type" "string" }} + + {{ template "go_primitives_encode" Params "Entries" .Int32s "Type" "int32" }} + {{ template "go_arrays_encode" Params "Entries" .Int32Arrays "Type" "int32" }} + {{ template "go_maps_encode" Params "Entries" .Int32Maps "Type" "int32" }} + + {{ template "go_primitives_encode" Params "Entries" .Int64s "Type" "int64" }} + {{ template "go_arrays_encode" Params "Entries" .Int64Arrays "Type" "int64" }} + {{ template "go_maps_encode" Params "Entries" .Int64Maps "Type" "int64" }} + + {{ template "go_primitives_encode" Params "Entries" .Uint32s "Type" "uint32" }} + {{ template "go_arrays_encode" Params "Entries" .Uint32Arrays "Type" "uint32" }} + {{ template "go_maps_encode" Params "Entries" .Uint32Maps "Type" "uint32" }} + + {{ template "go_primitives_encode" Params "Entries" .Uint64s "Type" "uint64" }} + {{ template "go_arrays_encode" Params "Entries" .Uint64Arrays "Type" "uint64" }} + {{ template "go_maps_encode" Params "Entries" .Uint64Maps "Type" "uint64" }} + + {{ template "go_primitives_encode" Params "Entries" .Float32s "Type" "float32" }} + {{ template "go_arrays_encode" Params "Entries" .Float32Arrays "Type" "float32" }} + + {{ template "go_primitives_encode" Params "Entries" .Float64s "Type" "float64" }} + {{ template "go_arrays_encode" Params "Entries" .Float64Arrays "Type" "float64" }} + + {{ template "go_enums_encode" . }} + {{ template "go_enumarrays_encode" . }} + {{ template "go_enummaps_encode" . }} + + {{ template "go_primitives_encode" Params "Entries" .Bytes "Type" "bytes" }} + {{ template "go_arrays_encode" Params "Entries" .BytesArrays "Type" "bytes" }} + + {{ template "go_primitives_encode" Params "Entries" .Bools "Type" "bool" }} + {{ template "go_arrays_encode" Params "Entries" .BoolArrays "Type" "bool" }} + } + } + + func Decode{{ .Name }}(x *{{ .Name }}, b []byte) (*{{ .Name }}, error) { + d := polyglot.GetDecoder(b) + defer d.Return() + return _decode{{ .Name }}(x, d) + } + + func _decode{{ .Name }}(x *{{ .Name }}, d *polyglot.Decoder) (*{{ .Name }}, error) { + if d.Nil() { + return nil, nil + } + + err, _ := d.Error() + if err != nil { + return nil, err + } + + if x == nil { + x = New{{ .Name }}() + } + + {{ template "go_models_decode" . }} + {{ template "go_modelarrays_decode" . }} + + {{ template "go_primitives_decode" Params "Entries" .Strings "Type" "string" }} + {{ template "go_arrays_decode" Params "Entries" .StringArrays "Type" "string" }} + {{ template "go_maps_decode" Params "Entries" .StringMaps "Type" "string" }} + + {{ template "go_primitives_decode" Params "Entries" .Int32s "Type" "int32" }} + {{ template "go_arrays_decode" Params "Entries" .Int32Arrays "Type" "int32" }} + {{ template "go_maps_decode" Params "Entries" .Int32Maps "Type" "int32" }} + + {{ template "go_primitives_decode" Params "Entries" .Int64s "Type" "int64" }} + {{ template "go_arrays_decode" Params "Entries" .Int64Arrays "Type" "int64" }} + {{ template "go_maps_decode" Params "Entries" .Int64Maps "Type" "int64" }} + + {{ template "go_primitives_decode" Params "Entries" .Uint32s "Type" "uint32" }} + {{ template "go_arrays_decode" Params "Entries" .Uint32Arrays "Type" "uint32" }} + {{ template "go_maps_decode" Params "Entries" .Uint32Maps "Type" "uint32" }} + + {{ template "go_primitives_decode" Params "Entries" .Uint64s "Type" "uint64" }} + {{ template "go_arrays_decode" Params "Entries" .Uint64Arrays "Type" "uint64" }} + {{ template "go_maps_decode" Params "Entries" .Uint64Maps "Type" "uint64" }} + + {{ template "go_primitives_decode" Params "Entries" .Float32s "Type" "float32" }} + {{ template "go_arrays_decode" Params "Entries" .Float32Arrays "Type" "float32" }} + + {{ template "go_primitives_decode" Params "Entries" .Float64s "Type" "float64" }} + {{ template "go_arrays_decode" Params "Entries" .Float64Arrays "Type" "float64" }} + + {{ template "go_enums_decode" Params "Model" . "Enums" $allEnums }} + {{ template "go_enumarrays_decode" . }} + {{ template "go_enummaps_decode" . }} + + {{ template "go_bytes_decode" Params "Entries" .Bytes "Type" "bytes" }} + {{ template "go_bytesarrays_decode" Params "Entries" .BytesArrays "Type" "bytes" }} + + {{ template "go_primitives_decode" Params "Entries" .Bools "Type" "bool" }} + {{ template "go_arrays_decode" Params "Entries" .BoolArrays "Type" "bool" }} + + return x, nil + } + + {{ template "go_models_accessor" . }} + {{ template "go_modelarrays_accessor" . }} + + {{ template "go_enums_accessor" . }} + + {{ template "go_strings_accessor" Params "Model" . "Entries" .Strings "Type" "string" }} + {{ template "go_numbers_accessor" Params "Model" . "Entries" .Int32s "Type" "int32" }} + {{ template "go_numbers_accessor" Params "Model" . "Entries" .Int64s "Type" "int64" }} + {{ template "go_numbers_accessor" Params "Model" . "Entries" .Uint32s "Type" "uint32" }} + {{ template "go_numbers_accessor" Params "Model" . "Entries" .Uint64s "Type" "uint64" }} + {{ template "go_numbers_accessor" Params "Model" . "Entries" .Float32s "Type" "float32" }} + {{ template "go_numbers_accessor" Params "Model" . "Entries" .Float64s "Type" "float32" }} + +{{ end -}} \ No newline at end of file diff --git a/extension/generator/golang/templates/validators.go.templ b/extension/generator/golang/templates/validators.go.templ new file mode 100644 index 00000000..0b17f187 --- /dev/null +++ b/extension/generator/golang/templates/validators.go.templ @@ -0,0 +1,53 @@ +{{ define "go_numbers_limit_validator" }} +{{- if . }} + {{- if and .Maximum .Minimum }} + if v > {{ .Maximum }} || v < {{ .Minimum }} { + return fmt.Errorf("value must be between {{ .Minimum }} and {{ .Maximum }}") + } + {{- else if .Minimum }} + if v < {{ .Minimum }} { + return fmt.Errorf("value must be greater than or equal to {{ .Minimum }}") + } + {{- else if .Maximum }} + if v > {{ .Maximum }} { + return fmt.Errorf("value must be less than or equal to {{ .Maximum }}") + } + {{- end }} +{{- end }} +{{ end }} + +{{ define "go_regex_validator" }} + {{- if . }} + if matched, err := regexp.MatchString(`{{ .Expression }}`, v); err != nil || !matched { + return fmt.Errorf("value must match {{ .Expression }}") + } + {{- end }} +{{ end }} + +{{ define "go_length_validator" }} + {{- if . }} + {{- if and .Maximum .Minimum }} + if len(v) > {{ .Maximum }} || len(v) < {{ .Minimum }} { + return fmt.Errorf("length must be between {{ .Minimum }} and {{ .Maximum }}") + } + {{- else if .Minimum }} + if len(v) < {{ .Minimum }} { + return fmt.Errorf("length must be greater than or equal to {{ .Minimum }}") + } + {{- else if .Maximum }} + if len(v) > {{ .Maximum }} { + return fmt.Errorf("length must be less than or equal to {{ .Maximum }}") + } + {{- end }} + {{- end }} +{{ end }} + +{{ define "go_case_modifier" }} + {{- if . }} + {{- if eq .Kind "upper" }} + v = strings.ToUpper(v) + {{- else if eq .Kind "lower" }} + v = strings.ToLower(v) + {{- end }} + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/rust/generated.txt b/extension/generator/rust/generated.txt new file mode 100644 index 00000000..6b4ea1b1 --- /dev/null +++ b/extension/generator/rust/generated.txt @@ -0,0 +1,331 @@ +// Code generated by scale-extension 0.4.1, DO NOT EDIT. +// output: types + +#![allow(dead_code)] +#![allow(unused_imports)] +#![allow(unused_variables)] +#![allow(unused_mut)] +use std::io::Cursor; +use polyglot_rs::{DecodingError, Encoder, Decoder, Kind}; +use num_enum::TryFromPrimitive; +use std::convert::TryFrom; +use std::collections::HashMap; +use regex::Regex; +pub trait Encode { + fn encode<'a>( + a: Option<&Self>, + b: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> + where + Self: Sized; +} +trait EncodeSelf { + fn encode_self<'a, 'b>( + &'b self, + b: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box>; +} +pub trait Decode { + fn decode( + b: &mut Cursor<&mut Vec>, + ) -> Result, Box> + where + Self: Sized; +} +#[derive(Clone, Debug, PartialEq)] +pub struct HttpConfig { + pub timeout: i32, +} +impl HttpConfig { + pub fn new() -> Self { + Self { timeout: 60 } + } +} +impl Encode for HttpConfig { + fn encode<'a>( + a: Option<&HttpConfig>, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + a.encode_self(e) + } +} +impl EncodeSelf for HttpConfig { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + e.encode_i32(self.timeout)?; + Ok(e) + } +} +impl EncodeSelf for Option<&HttpConfig> { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + if let Some(x) = self { + x.encode_self(e)?; + } else { + e.encode_none()?; + } + Ok(e) + } +} +impl EncodeSelf for Option { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + if let Some(x) = self { + x.encode_self(e)?; + } else { + e.encode_none()?; + } + Ok(e) + } +} +impl Decode for HttpConfig { + fn decode( + d: &mut Cursor<&mut Vec>, + ) -> Result, Box> { + if d.decode_none() { + return Ok(None); + } + if let Ok(error) = d.decode_error() { + return Err(error); + } + let mut x = HttpConfig::new(); + x.timeout = d.decode_i32()?; + Ok(Some(x)) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct HttpResponse { + pub headers: HashMap, + pub status_code: i32, + pub body: Vec, +} +impl HttpResponse { + pub fn new() -> Self { + Self { + headers: HashMap::new(), + status_code: 0, + body: Vec::with_capacity(0), + } + } +} +impl Encode for HttpResponse { + fn encode<'a>( + a: Option<&HttpResponse>, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + a.encode_self(e) + } +} +impl EncodeSelf for HttpResponse { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + e.encode_map(self.headers.len(), Kind::String, Kind::Any)?; + for (k, v) in &self.headers { + e.encode_string(&k)?; + v.encode_self(e)?; + } + e.encode_i32(self.status_code)?; + e.encode_bytes(&self.body)?; + Ok(e) + } +} +impl EncodeSelf for Option<&HttpResponse> { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + if let Some(x) = self { + x.encode_self(e)?; + } else { + e.encode_none()?; + } + Ok(e) + } +} +impl EncodeSelf for Option { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + if let Some(x) = self { + x.encode_self(e)?; + } else { + e.encode_none()?; + } + Ok(e) + } +} +impl Decode for HttpResponse { + fn decode( + d: &mut Cursor<&mut Vec>, + ) -> Result, Box> { + if d.decode_none() { + return Ok(None); + } + if let Ok(error) = d.decode_error() { + return Err(error); + } + let mut x = HttpResponse::new(); + let size_headers = d.decode_map(Kind::String, Kind::Any)?; + for _ in 0..size_headers { + let k = d.decode_string()?; + let v = StringList::decode(d)?.ok_or(DecodingError::InvalidMap)?; + x.headers.insert(k, v); + } + x.status_code = d.decode_i32()?; + x.body = d.decode_bytes()?; + Ok(Some(x)) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct StringList { + pub values: Vec, +} +impl StringList { + pub fn new() -> Self { + Self { + values: Vec::with_capacity(0), + } + } +} +impl Encode for StringList { + fn encode<'a>( + a: Option<&StringList>, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + a.encode_self(e) + } +} +impl EncodeSelf for StringList { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + e.encode_array(self.values.len(), Kind::String)?; + for a in &self.values { + e.encode_string(&a)?; + } + Ok(e) + } +} +impl EncodeSelf for Option<&StringList> { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + if let Some(x) = self { + x.encode_self(e)?; + } else { + e.encode_none()?; + } + Ok(e) + } +} +impl EncodeSelf for Option { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + if let Some(x) = self { + x.encode_self(e)?; + } else { + e.encode_none()?; + } + Ok(e) + } +} +impl Decode for StringList { + fn decode( + d: &mut Cursor<&mut Vec>, + ) -> Result, Box> { + if d.decode_none() { + return Ok(None); + } + if let Ok(error) = d.decode_error() { + return Err(error); + } + let mut x = StringList::new(); + let size_values = d.decode_array(Kind::String)?; + for _ in 0..size_values { + x.values.push(d.decode_string()?); + } + Ok(Some(x)) + } +} +#[derive(Clone, Debug, PartialEq)] +pub struct ConnectionDetails { + pub url: String, +} +impl ConnectionDetails { + pub fn new() -> Self { + Self { + url: "https://google.com".to_string(), + } + } +} +impl Encode for ConnectionDetails { + fn encode<'a>( + a: Option<&ConnectionDetails>, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + a.encode_self(e) + } +} +impl EncodeSelf for ConnectionDetails { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + e.encode_string(&self.url)?; + Ok(e) + } +} +impl EncodeSelf for Option<&ConnectionDetails> { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + if let Some(x) = self { + x.encode_self(e)?; + } else { + e.encode_none()?; + } + Ok(e) + } +} +impl EncodeSelf for Option { + fn encode_self<'a, 'b>( + &'b self, + e: &'a mut Cursor>, + ) -> Result<&'a mut Cursor>, Box> { + if let Some(x) = self { + x.encode_self(e)?; + } else { + e.encode_none()?; + } + Ok(e) + } +} +impl Decode for ConnectionDetails { + fn decode( + d: &mut Cursor<&mut Vec>, + ) -> Result, Box> { + if d.decode_none() { + return Ok(None); + } + if let Ok(error) = d.decode_error() { + return Err(error); + } + let mut x = ConnectionDetails::new(); + x.url = d.decode_string()?; + Ok(Some(x)) + } +} diff --git a/extension/generator/rust/generator.go b/extension/generator/rust/generator.go new file mode 100644 index 00000000..0e42aaab --- /dev/null +++ b/extension/generator/rust/generator.go @@ -0,0 +1,286 @@ +/* + Copyright 2023 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package rust + +import ( + "bytes" + "context" + "strings" + "text/template" + + interfacesVersion "github.com/loopholelabs/scale-extension-interfaces/version" + + polyglotVersion "github.com/loopholelabs/polyglot/version" + + scaleVersion "github.com/loopholelabs/scale/version" + + polyglotUtils "github.com/loopholelabs/polyglot/utils" + + "github.com/loopholelabs/scale/extension" + "github.com/loopholelabs/scale/extension/generator/rust/templates" + "github.com/loopholelabs/scale/signature/generator/rust/format" + "github.com/loopholelabs/scale/signature/generator/utils" +) + +const ( + defaultPackageName = "types" +) + +var generator *Generator + +// GenerateTypes generates the types for the extension +func GenerateTypes(extensionSchema *extension.Schema, packageName string) ([]byte, error) { + return generator.GenerateTypes(extensionSchema, packageName) +} + +// GenerateCargofile generates the cargo.toml file for the extension +func GenerateCargofile(packageName string, packageVersion string) ([]byte, error) { + return generator.GenerateCargofile(packageName, packageVersion) +} + +func GenerateGuest(extensionSchema *extension.Schema, signatureHash string, packageName string) ([]byte, error) { + return generator.GenerateGuest(extensionSchema, signatureHash, packageName) +} + +func init() { + var err error + generator, err = New() + if err != nil { + panic(err) + } +} + +// Generator is the rust generator +type Generator struct { + templ *template.Template + formatter *format.Formatter +} + +// New creates a new rust generator +func New() (*Generator, error) { + templ, err := template.New("").Funcs(templateFunctions()).ParseFS(templates.FS, "*.rs.templ") + if err != nil { + return nil, err + } + + formatter, err := format.New() + if err != nil { + return nil, err + } + + return &Generator{ + templ: templ, + formatter: formatter, + }, nil +} + +// GenerateTypes generates the types for the extension +func (g *Generator) GenerateTypes(extensionSchema *extension.Schema, packageName string) ([]byte, error) { + + schema, err := extensionSchema.CloneWithDisabledAccessorsValidatorsAndModifiers() + if err != nil { + return nil, err + } + + if packageName == "" { + packageName = defaultPackageName + } + + buf := new(bytes.Buffer) + err = g.templ.ExecuteTemplate(buf, "types.rs.templ", map[string]any{ + "signature_schema": schema, + "package_name": packageName, + }) + if err != nil { + return nil, err + } + + formatted, err := g.formatter.Format(context.Background(), buf.String()) + if err != nil { + return nil, err + } + + buf.Reset() + err = g.templ.ExecuteTemplate(buf, "header.rs.templ", map[string]any{ + "generator_version": strings.Trim(scaleVersion.Version(), "v"), + "package_name": packageName, + }) + if err != nil { + return nil, err + } + return []byte(buf.String() + "\n\n" + formatted), nil +} + +// GenerateCargofile generates the cargofile for the extension +func (g *Generator) GenerateCargofile(packageName string, packageVersion string) ([]byte, error) { + buf := new(bytes.Buffer) + err := g.templ.ExecuteTemplate(buf, "cargo.rs.templ", map[string]any{ + "polyglot_version": strings.TrimPrefix(polyglotVersion.Version(), "v"), + "scale_signature_interfaces_version": strings.TrimPrefix(interfacesVersion.Version(), "v"), + "package_name": packageName, + "package_version": strings.TrimPrefix(packageVersion, "v"), + }) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// GenerateGuest generates the guest bindings +func (g *Generator) GenerateGuest(extensionSchema *extension.Schema, extensionHash string, packageName string) ([]byte, error) { + if packageName == "" { + packageName = defaultPackageName + } + + buf := new(bytes.Buffer) + err := g.templ.ExecuteTemplate(buf, "guest.rs.templ", map[string]any{ + "extension_schema": extensionSchema, + "extension_hash": extensionHash, + }) + if err != nil { + return nil, err + } + + formatted := buf.String() + /* + formatted, err := g.formatter.Format(context.Background(), buf.String()) + if err != nil { + return nil, err + } + */ + buf.Reset() + err = g.templ.ExecuteTemplate(buf, "header.rs.templ", map[string]any{ + "generator_version": strings.TrimPrefix(scaleVersion.Version(), "v"), + "package_name": packageName, + }) + if err != nil { + return nil, err + } + return []byte(buf.String() + "\n\n" + formatted), nil +} + +func templateFunctions() template.FuncMap { + return template.FuncMap{ + "Primitive": primitive, + "IsPrimitive": extension.ValidPrimitiveType, + "PolyglotPrimitive": polyglotPrimitive, + "PolyglotPrimitiveEncode": polyglotPrimitiveEncode, + "PolyglotPrimitiveDecode": polyglotPrimitiveDecode, + "Deref": func(i *bool) bool { return *i }, + "LowerFirst": func(s string) string { return string(s[0]+32) + s[1:] }, + "SnakeCase": polyglotUtils.SnakeCase, + "Params": utils.Params, + } +} + +func primitive(t string) string { + switch t { + case "string": + return "String" + case "int32": + return "i32" + case "int64": + return "i64" + case "uint32": + return "u32" + case "uint64": + return "u64" + case "float32": + return "f32" + case "float64": + return "f64" + case "bool": + return "bool" + case "bytes": + return "Vec" + default: + return t + } +} + +func polyglotPrimitive(t string) string { + switch t { + case "string": + return "Kind::String" + case "int32": + return "Kind::I32" + case "int64": + return "Kind::I64" + case "uint32": + return "Kind::U32" + case "uint64": + return "Kind::U64" + case "float32": + return "Kind::F32" + case "float64": + return "Kind::F64" + case "bool": + return "Kind::Bool" + case "bytes": + return "Kind::Bytes" + default: + return "Kind::Any" + } +} + +func polyglotPrimitiveEncode(t string) string { + switch t { + case "string": + return "encode_string" + case "int32": + return "encode_i32" + case "int64": + return "encode_i64" + case "uint32": + return "encode_u32" + case "uint64": + return "encode_u64" + case "float32": + return "encode_f32" + case "float64": + return "encode_f64" + case "bool": + return "encode_bool" + case "bytes": + return "encode_bytes" + default: + return t + } +} + +func polyglotPrimitiveDecode(t string) string { + switch t { + case "string": + return "decode_string" + case "int32": + return "decode_i32" + case "int64": + return "decode_i64" + case "uint32": + return "decode_u32" + case "uint64": + return "decode_u64" + case "float32": + return "decode_f32" + case "float64": + return "decode_f64" + case "bool": + return "decode_bool" + case "bytes": + return "decode_bytes" + default: + return "" + } +} diff --git a/extension/generator/rust/generator_test.go b/extension/generator/rust/generator_test.go new file mode 100644 index 00000000..681e4c86 --- /dev/null +++ b/extension/generator/rust/generator_test.go @@ -0,0 +1,54 @@ +//go:build !integration && !generate + +/* + Copyright 2023 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package rust + +import ( + "encoding/hex" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/loopholelabs/scale/extension" +) + +func TestGenerator(t *testing.T) { + s := new(extension.Schema) + err := s.Decode([]byte(extension.MasterTestingSchema)) + require.NoError(t, err) + + formatted, err := GenerateTypes(s, "types") + require.NoError(t, err) + + os.WriteFile("./generated.txt", formatted, 0644) + /* + master, err := os.ReadFile("./generated.txt") + require.NoError(t, err) + require.Equal(t, string(master), string(formatted)) + */ + t.Log(string(formatted)) + + sHash, err := s.Hash() + h := hex.EncodeToString(sHash) + + guest, err := GenerateGuest(s, h, "guest") + require.NoError(t, err) + os.WriteFile("./guest.txt", guest, 0644) + expGuest, err := os.ReadFile("./guest.txt") + require.NoError(t, err) + require.Equal(t, string(expGuest), string(guest)) + +} diff --git a/extension/generator/rust/guest.txt b/extension/generator/rust/guest.txt new file mode 100644 index 00000000..a2d62cbb --- /dev/null +++ b/extension/generator/rust/guest.txt @@ -0,0 +1,81 @@ +// Code generated by scale-extension 0.4.1, DO NOT EDIT. +// output: guest + +pub mod types; +use crate::types::{Encode, Decode}; + +use std::io::Cursor; +use polyglot_rs::{Encoder}; + +static HASH: &'static str = "3914ee157703d809e20bf4e9f4a6d0cf0db287ec4b3dcfe4982c25b0101bc156"; + +static mut READ_BUFFER: Vec = Vec::new(); +static mut WRITE_BUFFER: Vec = Vec::new(); + +// resize resizes the extensions READ_BUFFER to the given size and returns the pointer to the buffer +// +// Users should not use this method. +#[export_name = "ext_HttpFetch_Resize"] +#[no_mangle] +pub unsafe fn ext_HttpFetch_Resize(size: u32) -> *const u8 { + READ_BUFFER.resize(size as usize, 0); + return READ_BUFFER.as_ptr(); +} + + + +// Define any interfaces we need here... +// Also define structs we can use to hold instanceId + + + +// Define concrete types with a hidden instanceId HttpConnector + +// type _HttpConnector struct { +// instanceId uint64 +// } + + +// func (d *_HttpConnector) Fetch(params *ConnectionDetails) (HttpResponse, error) { +// } + +//export ext_HttpFetch_HttpConnector_Fetch +//go:linkname ext_HttpFetch_HttpConnector_Fetch +//func ext_HttpFetch_HttpConnector_Fetch(instance uint64, offset uint32, length uint32) uint64 + + + + + +// Define any global functions here... + + + +//export ext_HttpFetch_New +//go:linkname ext_HttpFetch_New +//func ext_HttpFetch_New(instance uint64, offset uint32, length uint32) uint64 + +//func New(params *HttpConfig) (HttpConnector, error) { +//} + + + +// error serializes an error into the global WRITE_BUFFER and returns a pointer to the buffer and its size +// +// Users should not use this method. +pub unsafe fn error(error: Box) -> (u32, u32) { + let mut cursor = Cursor::new(Vec::new()); + return match cursor.encode_error(error) { + Ok(_) => { + let vec = cursor.into_inner(); + + WRITE_BUFFER.resize(vec.len() as usize, 0); + WRITE_BUFFER.copy_from_slice(&vec); + + (WRITE_BUFFER.as_ptr() as u32, WRITE_BUFFER.len() as u32) + } + Err(_) => { + (0, 0) + } + }; +} diff --git a/extension/generator/rust/templates/arrays.rs.templ b/extension/generator/rust/templates/arrays.rs.templ new file mode 100644 index 00000000..2d15e2ad --- /dev/null +++ b/extension/generator/rust/templates/arrays.rs.templ @@ -0,0 +1,37 @@ +{{ define "rs_arrays_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + {{ SnakeCase .Name }}: Vec<{{ Primitive $type }}>, + {{- else }} + pub {{ SnakeCase .Name }}: Vec<{{ Primitive $type }}>, + {{- end -}} + {{- end }} +{{ end }} + +{{ define "rs_arrays_new_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{ SnakeCase .Name }}: Vec::with_capacity({{ .InitialSize }}), + {{- end }} +{{ end }} + +{{ define "rs_arrays_encode" }} + {{ $type := .Type }} + {{- range .Entries }} + e.encode_array(self.{{ SnakeCase .Name }}.len(), {{ PolyglotPrimitive $type }})?; + for a in &self.{{ SnakeCase .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(*a)?; + } + {{- end }} +{{ end }} + +{{ define "rs_arrays_decode" }} + {{ $type := .Type }} + {{- range .Entries }} + let size_{{ SnakeCase .Name }} = d.decode_array({{ PolyglotPrimitive $type }})?; + for _ in 0..size_{{ SnakeCase .Name }} { + x.{{ SnakeCase .Name }}.push(d.{{ PolyglotPrimitiveDecode $type }}()?); + } + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/rust/templates/cargo.rs.templ b/extension/generator/rust/templates/cargo.rs.templ new file mode 100644 index 00000000..efb7fa5f --- /dev/null +++ b/extension/generator/rust/templates/cargo.rs.templ @@ -0,0 +1,24 @@ +[package] +edition = "2021" +name = "{{ .package_name }}" +version = "{{ .package_version }}" + +[profile.release] +opt-level = 3 +lto = true +codegen-units = 1 + +[lib] +path = "guest.rs" + +[dependencies.num_enum] +version = "0.7.0" + +[dependencies.regex] +version = "1.9.4" + +[dependencies.scale_signature_interfaces] +version = "{{ .scale_signature_interfaces_version }}" + +[dependencies.polyglot_rs] +version = "{{ .polyglot_version }}" diff --git a/extension/generator/rust/templates/enumarrays.rs.templ b/extension/generator/rust/templates/enumarrays.rs.templ new file mode 100644 index 00000000..bdd88db9 --- /dev/null +++ b/extension/generator/rust/templates/enumarrays.rs.templ @@ -0,0 +1,36 @@ +{{ define "rs_enumarrays_struct_reference" }} + {{ $current_model := . }} + {{- range .EnumArrays }} + {{- if (Deref .Accessor) }} + {{ SnakeCase .Name }}: Vec<{{ .Reference }}>, + {{- else }} + pub {{ SnakeCase .Name }}: Vec<{{ .Reference }}>, + {{- end -}} + {{- end }} +{{ end }} + +{{ define "rs_enumarrays_new_struct_reference" }} + {{ $current_model := . }} + {{- range .EnumArrays }} + {{ SnakeCase .Name }}: Vec::with_capacity({{ .InitialSize }}), + {{- end }} +{{ end }} + +{{ define "rs_enumarrays_encode" }} + {{- range .EnumArrays }} + e.encode_array(self.{{ SnakeCase .Name}}.len(), Kind::U32)?; + for a in &self.{{ SnakeCase .Name}} { + e.encode_u32(*a as u32)?; + } + {{- end }} +{{ end }} + +{{ define "rs_enumarrays_decode" }} + {{ $current_model := . }} + {{- range .EnumArrays }} + let size_{{ SnakeCase .Name }} = d.decode_array(Kind::U32)?; + for _ in 0..size_{{ SnakeCase .Name }} { + x.{{ SnakeCase .Name }}.push({{ .Reference }}::try_from(d.decode_u32()?)?); + } + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/rust/templates/enummaps.rs.templ b/extension/generator/rust/templates/enummaps.rs.templ new file mode 100644 index 00000000..ace8824b --- /dev/null +++ b/extension/generator/rust/templates/enummaps.rs.templ @@ -0,0 +1,70 @@ +{{ define "rs_enummaps_struct_reference" }} + {{ $current_model := . }} + {{- range .EnumMaps }} + {{- if and (Deref .Accessor) (IsPrimitive .Value) }} + {{ SnakeCase .Name }}: HashMap<{ .Reference }}, {{ Primitive .Value }}>, + {{- end }} + + {{- if and (Deref .Accessor) (not (IsPrimitive .Value)) }} + {{ SnakeCase .Name }}: HashMap<{{ .Reference }}, {{ .Value }}>, + {{- end }} + + {{- if and (not (Deref .Accessor)) (IsPrimitive .Value) }} + pub {{ SnakeCase .Name }}: HashMap<{{ .Reference }}, {{ Primitive .Value }}>, + {{- end }} + + {{- if and (not (Deref .Accessor)) (not (IsPrimitive .Value)) }} + pub {{ SnakeCase .Name }}: HashMap<{{ .Reference }}, {{ .Value }}>, + {{- end }} + {{- end }} +{{ end }} + +{{ define "rs_enummaps_new_struct_reference" }} + {{ $current_model := . }} + {{- range .EnumMaps }} + {{ SnakeCase .Name }}: HashMap::new(), + {{ end }} +{{ end }} + +{{ define "rs_enummaps_encode" }} + {{- range .EnumMaps }} + {{- if IsPrimitive .Value }} + e.encode_map(self.{{ SnakeCase .Name }}.len(), Kind::U32, {{ PolyglotPrimitive .Value }})?; + for (k, v) in &self.{{ SnakeCase .Name }} { + e.encode_u32(*k as u32)?; + {{- if eq .Value "string"}} + e.{{ PolyglotPrimitiveEncode .Value }}(&v)?; + {{- else }} + e.{{ PolyglotPrimitiveEncode .Value }}(v)?; + {{- end }} + } + {{- else }} + e.encode_map(self.{{ SnakeCase .Name }}.len(), Kind::U32, Kind::Any)?; + for (k, v) in &self.{{ SnakeCase .Name }} { + e.encode_u32(*k as u32)?; + v.encode_self(e)?; + } + {{- end }} + {{- end }} +{{ end }} + +{{ define "rs_enummaps_decode" }} + {{ $current_model := . }} + {{- range .EnumMaps }} + {{- if IsPrimitive .Value }} + let size_{{ SnakeCase .Name }} = d.decode_map(Kind::U32, {{ PolyglotPrimitive .Value }})?; + for _ in 0..size_{{ SnakeCase .Name }} { + let k = {{ .Reference }}::try_from(d.decode_u32()?)?; + let v = d.{{ PolyglotPrimitiveDecode .Value }}()?; + x.{{ SnakeCase .Name }}.insert(k, v); + } + {{- else }} + let size_{{ SnakeCase .Name }} = d.decode_map(Kind::U32, Kind::Any)?; + for _ in 0..size_{{ SnakeCase .Name }} { + let k = {{ .Reference }}::try_from(d.decode_u32()?)?; + let v = {{ .Value }}::decode(d)?.ok_or(DecodingError::InvalidMap)?; + x.{{ SnakeCase .Name }}.insert(k, v); + } + {{- end }} + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/rust/templates/enums.rs.templ b/extension/generator/rust/templates/enums.rs.templ new file mode 100644 index 00000000..800efa56 --- /dev/null +++ b/extension/generator/rust/templates/enums.rs.templ @@ -0,0 +1,58 @@ +{{ define "rs_enums_definition" }} + #[derive(Debug, Eq, PartialEq, TryFromPrimitive, Copy, Clone, Hash)] + #[repr(u32)] + pub enum {{ .Name }} { + {{- range $index, $value := .Values }} + {{ $value }} = {{ $index }}, + {{- end }} + } +{{ end }} + +{{ define "rs_enums_struct_reference" }} + {{ $current_model := . }} + {{- range .Enums }} + {{- if (Deref .Accessor) }} + {{ SnakeCase .Name }}: {{ .Reference }}, + {{- else }} + pub {{ SnakeCase .Name }}: {{ .Reference }}, + {{- end -}} + {{ end }} +{{ end }} + +{{ define "rs_enums_new_struct_reference" }} + {{ $current_model := . }} + {{- range .Enums }} + {{ SnakeCase .Name }}: {{ .Reference }}::{{ .Default }}, + {{ end }} +{{ end }} + +{{ define "rs_enums_encode" }} + {{ $current_model := . }} + {{- range .Enums }} + e.encode_u32(self.{{ SnakeCase .Name }} as u32)?; + {{- end }} +{{ end }} + +{{ define "rs_enums_decode" }} + {{ $current_model := . }} + {{- range .Enums }} + x.{{ SnakeCase .Name }} = {{ .Reference }}::try_from(d.decode_u32()?).ok().ok_or(DecodingError::InvalidEnum)?; + {{- end }} +{{ end }} + +{{ define "rs_enums_accessor" }} + {{ $current_model := . }} + {{- range .Enums }} + {{- if .Accessor }} + impl {{ $current_model.Name }} { + pub fn get_{{ SnakeCase .Name }}(&self) -> &{{ .Reference }} { + &self.{{ SnakeCase .Name }} + } + + pub fn set_{{ SnakeCase .Name }}(&mut self, v: {{ .Reference }}) { + self.{{ SnakeCase .Name }} = v; + } + } + {{- end -}} + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/rust/templates/guest.rs.templ b/extension/generator/rust/templates/guest.rs.templ new file mode 100644 index 00000000..e4421f25 --- /dev/null +++ b/extension/generator/rust/templates/guest.rs.templ @@ -0,0 +1,78 @@ +pub mod types; +use crate::types::{Encode, Decode}; + +use std::io::Cursor; +use polyglot_rs::{Encoder}; + +static HASH: &'static str = "{{ .extension_hash }}"; + +static mut READ_BUFFER: Vec = Vec::new(); +static mut WRITE_BUFFER: Vec = Vec::new(); + +// resize resizes the extensions READ_BUFFER to the given size and returns the pointer to the buffer +// +// Users should not use this method. +#[export_name = "ext_{{ .extension_schema.Name }}_Resize"] +#[no_mangle] +pub unsafe fn ext_{{ .extension_schema.Name }}_Resize(size: u32) -> *const u8 { + READ_BUFFER.resize(size as usize, 0); + return READ_BUFFER.as_ptr(); +} + +{{ $schema := .extension_schema }} + +// Define any interfaces we need here... +// Also define structs we can use to hold instanceId + +{{ range $ifc := .extension_schema.Interfaces }} + +// Define concrete types with a hidden instanceId {{ $ifc.Name }} + +// type _{{ $ifc.Name }} struct { +// instanceId uint64 +// } + +{{ range $fn := $ifc.Functions }} +// func (d *_{{ $ifc.Name }}) {{ $fn.Name }}(params *{{ $fn.Params }}) ({{ $fn.Return }}, error) { +// } + +//export ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }} +//go:linkname ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }} +//func ext_{{ $schema.Name }}_{{ $ifc.Name }}_{{ $fn.Name }}(instance uint64, offset uint32, length uint32) uint64 + +{{ end }} + +{{ end }} + +// Define any global functions here... + +{{ range $fn := .extension_schema.Functions }} + +//export ext_{{ $schema.Name }}_{{ $fn.Name }} +//go:linkname ext_{{ $schema.Name }}_{{ $fn.Name }} +//func ext_{{ $schema.Name }}_{{ $fn.Name }}(instance uint64, offset uint32, length uint32) uint64 + +//func {{ $fn.Name }}(params *{{ $fn.Params }}) ({{ $fn.Return }}, error) { +//} + +{{ end }} + +// error serializes an error into the global WRITE_BUFFER and returns a pointer to the buffer and its size +// +// Users should not use this method. +pub unsafe fn error(error: Box) -> (u32, u32) { + let mut cursor = Cursor::new(Vec::new()); + return match cursor.encode_error(error) { + Ok(_) => { + let vec = cursor.into_inner(); + + WRITE_BUFFER.resize(vec.len() as usize, 0); + WRITE_BUFFER.copy_from_slice(&vec); + + (WRITE_BUFFER.as_ptr() as u32, WRITE_BUFFER.len() as u32) + } + Err(_) => { + (0, 0) + } + }; +} diff --git a/extension/generator/rust/templates/header.rs.templ b/extension/generator/rust/templates/header.rs.templ new file mode 100644 index 00000000..1c34c88c --- /dev/null +++ b/extension/generator/rust/templates/header.rs.templ @@ -0,0 +1,2 @@ +// Code generated by scale-extension {{ .generator_version }}, DO NOT EDIT. +// output: {{ .package_name }} \ No newline at end of file diff --git a/extension/generator/rust/templates/maps.rs.templ b/extension/generator/rust/templates/maps.rs.templ new file mode 100644 index 00000000..d8be9eda --- /dev/null +++ b/extension/generator/rust/templates/maps.rs.templ @@ -0,0 +1,58 @@ +{{ define "rs_maps_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + {{ SnakeCase .Name }}: HashMap<{{ Primitive $type }}, {{ Primitive .Value }}>, + {{- else }} + pub {{ SnakeCase .Name }}: HashMap<{{ Primitive $type }}, {{ Primitive .Value }}>, + {{- end }} + {{ end }} +{{ end }} + +{{ define "rs_maps_new_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{ SnakeCase .Name }}: HashMap::new(), + {{- end }} +{{ end }} + +{{ define "rs_maps_encode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if IsPrimitive .Value }} + e.encode_map(self.{{ SnakeCase .Name }}.len(), {{ PolyglotPrimitive $type }}, {{ PolyglotPrimitive .Value }})?; + for (k, v) in &self.{{ SnakeCase .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(*k)?; + e.{{ PolyglotPrimitiveEncode .Value }}(*v)?; + } + {{- else }} + e.encode_map(self.{{ SnakeCase .Name }}.len(), {{ PolyglotPrimitive $type }}, Kind::Any)?; + for (k, v) in &self.{{ SnakeCase .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(*k)?; + v.encode_self(e)?; + } + {{- end }} + + {{- end }} +{{ end }} + +{{ define "rs_maps_decode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if IsPrimitive .Value }} + let size_{{ SnakeCase .Name }} = d.decode_map({{ PolyglotPrimitive $type }}, {{ PolyglotPrimitive .Value }})?; + for _ in 0..size_{{ SnakeCase .Name }} { + let k = d.{{ PolyglotPrimitiveDecode $type }}()?; + let v = d.{{ PolyglotPrimitiveDecode .Value }}()?; + x.{{ SnakeCase .Name }}.insert(k, v); + } + {{- else }} + let size_{{ SnakeCase .Name }} = d.decode_map({{ PolyglotPrimitive $type }}, Kind::Any)?; + for _ in 0..size_{{ SnakeCase .Name }} { + let k = d.{{ PolyglotPrimitiveDecode $type }}()?; + let v = {{ .Value }}::decode(d)?.ok_or(DecodingError::InvalidMap)?; + x.{{ SnakeCase .Name }}.insert(k, v); + } + {{- end }} + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/rust/templates/modelarrays.rs.templ b/extension/generator/rust/templates/modelarrays.rs.templ new file mode 100644 index 00000000..af3746aa --- /dev/null +++ b/extension/generator/rust/templates/modelarrays.rs.templ @@ -0,0 +1,50 @@ +{{ define "rs_modelarrays_struct_reference" }} + {{- range .ModelArrays }} + {{- if .Accessor }} + {{ SnakeCase .Name }}: Vec<{{ .Reference }}>, + {{- else }} + pub {{ SnakeCase .Name }}: Vec<{{ .Reference }}>, + {{- end -}} + {{ end }} +{{ end }} + +{{ define "rs_modelarrays_new_struct_reference" }} + {{- range .ModelArrays }} + {{ SnakeCase .Name }}: Vec::with_capacity({{ .InitialSize }}), + {{- end }} +{{ end }} + +{{ define "rs_modelarrays_encode" }} + {{- range .ModelArrays }} + e.encode_array(self.{{ SnakeCase .Name }}.len(), Kind::Any)?; + for a in &self.{{ SnakeCase .Name }} { + a.encode_self(e)?; + } + {{- end }} +{{ end }} + +{{ define "rs_modelarrays_decode" }} + {{- range .ModelArrays }} + let size_{{ SnakeCase .Name }} = d.decode_array(Kind::Any)?; + for _ in 0..size_{{ SnakeCase .Name }} { + x.{{ SnakeCase .Name }}.push({{ .Reference }}::decode(d)?.ok_or(DecodingError::InvalidArray)?); + } + {{- end }} +{{ end }} + +{{ define "rs_modelarrays_accessor" }} + {{ $current_model := . }} + {{- range .ModelArrays }} + {{- if .Accessor }} + impl {{ $current_model.Name }} { + pub fn get_{{ SnakeCase .Name }} (&self) -> Option<&Vec<{{ .Reference }}>> { + Some(&self.{{ SnakeCase .Name }}) + } + + pub fn set_{{ SnakeCase .Name }} (&mut self, v: Vec<{{ .Reference }}>) { + self.{{ SnakeCase .Name }} = v; + } + } + {{- end -}} + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/rust/templates/models.rs.templ b/extension/generator/rust/templates/models.rs.templ new file mode 100644 index 00000000..c1334abd --- /dev/null +++ b/extension/generator/rust/templates/models.rs.templ @@ -0,0 +1,44 @@ +{{ define "rs_models_struct_reference" }} + {{- range .Models }} + {{- if .Accessor }} + {{ SnakeCase .Name }}: Option<{{ .Reference }}>, + {{- else }} + pub {{ SnakeCase .Name }}: Option<{{ .Reference }}>, + {{- end -}} + {{- end }} +{{ end }} + +{{ define "rs_models_new_struct_reference" }} + {{- range .Models }} + {{ SnakeCase .Name }}: Some({{ .Reference }}::new()), + {{- end }} +{{ end }} + +{{ define "rs_models_encode" }} + {{- range .Models }} + self.{{ SnakeCase .Name }}.encode_self(e)?; + {{- end }} +{{ end }} + +{{ define "rs_models_decode" }} + {{- range .Models }} + x.{{ SnakeCase .Name }} = {{ .Reference }}::decode(d)?; + {{- end }} +{{ end }} + +{{ define "rs_models_accessor" }} + {{ $current_model := . }} + {{- range .Models }} + {{- if .Accessor }} + impl {{ $current_model.Name }} { + pub fn get_{{ SnakeCase .Name }}(&self) -> &Option<{{ .Reference }}> { + &self.{{ SnakeCase .Name }} + } + + pub fn set_{{ SnakeCase .Name }}(&mut self, v: Option<{{ .Reference }}>) { + self.{{ SnakeCase .Name }} = v; + } + } + {{- end -}} + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/rust/templates/primitives.rs.templ b/extension/generator/rust/templates/primitives.rs.templ new file mode 100644 index 00000000..6d6d4c73 --- /dev/null +++ b/extension/generator/rust/templates/primitives.rs.templ @@ -0,0 +1,90 @@ +{{ define "rs_primitives_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + {{ SnakeCase .Name }}: {{ Primitive $type }}, + {{- else }} + pub {{ SnakeCase .Name }}: {{ Primitive $type }}, + {{- end -}} + {{ end }} +{{ end }} + +{{ define "rs_primitives_new_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{ SnakeCase .Name }}: {{ .Default }}, + {{ end }} +{{ end }} + +{{ define "rs_strings_new_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{ SnakeCase .Name }}: "{{ .Default }}".to_string(), + {{ end }} +{{ end }} + +{{ define "rs_bytes_new_struct_reference" }} + {{ $type := .Type }} + {{- range .Entries }} + {{ SnakeCase .Name }}: Vec::with_capacity({{ .InitialSize }}), + {{ end }} +{{ end }} + +{{ define "rs_primitives_encode" }} + {{ $type := .Type }} + {{- range .Entries }} + e.{{ PolyglotPrimitiveEncode $type }}(self.{{ SnakeCase .Name }})?; + {{- end }} +{{ end}} + +{{ define "rs_ref_encode" }} + {{ $type := .Type }} + {{- range .Entries }} + e.{{ PolyglotPrimitiveEncode $type }}(&self.{{ SnakeCase .Name }})?; + {{- end }} +{{ end}} + +{{ define "rs_primitives_decode" }} + {{ $type := .Type }} + {{- range .Entries }} + x.{{ SnakeCase .Name }} = d.{{ PolyglotPrimitiveDecode $type }}()?; + {{- end }} +{{ end}} + +{{ define "rs_numbers_accessor" }} + {{ $type := .Type }} + {{ $model := .Model }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + pub fn get_{{ SnakeCase .Name }}(&self) -> {{ Primitive $type }} { + self.{{ SnakeCase .Name }} + } + + pub fn set_{{ SnakeCase .Name }}(&mut self, v: {{ Primitive $type }}) -> Result<(), Box> { + {{- template "rs_numbers_limit_validator" .LimitValidator }} + self.{{ SnakeCase .Name }} = v; + Ok(()) + } + {{- end -}} + {{ end }} +{{ end }} + +{{ define "rs_strings_accessor" }} + {{ $type := .Type }} + {{ $model := .Model }} + {{- range .Entries }} + {{- if (Deref .Accessor) }} + pub fn get_{{ SnakeCase .Name }}(&self) -> {{ Primitive $type }} { + self.{{ SnakeCase .Name }}.clone() + } + + pub fn set_{{ SnakeCase .Name }}(&mut self, mut v: {{ Primitive $type }}) -> Result<(), Box> { + {{- template "rs_regex_validator" .RegexValidator }} + {{- template "rs_length_validator" .LengthValidator }} + {{- template "rs_case_modifier" .CaseModifier }} + self.{{ SnakeCase .Name }} = v; + Ok(()) + } + {{- end -}} + {{ end }} +{{ end }} diff --git a/extension/generator/rust/templates/refarrays.rs.templ b/extension/generator/rust/templates/refarrays.rs.templ new file mode 100644 index 00000000..bfa6067b --- /dev/null +++ b/extension/generator/rust/templates/refarrays.rs.templ @@ -0,0 +1,9 @@ +{{ define "rs_refarrays_encode" }} + {{ $type := .Type }} + {{- range .Entries }} + e.encode_array(self.{{ SnakeCase .Name }}.len(), {{ PolyglotPrimitive $type }})?; + for a in &self.{{ SnakeCase .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(&a)?; + } + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/rust/templates/refmaps.rs.templ b/extension/generator/rust/templates/refmaps.rs.templ new file mode 100644 index 00000000..cb53ca9f --- /dev/null +++ b/extension/generator/rust/templates/refmaps.rs.templ @@ -0,0 +1,19 @@ +{{ define "rs_refmaps_encode" }} + {{ $type := .Type }} + {{- range .Entries }} + {{- if IsPrimitive .Value }} + e.encode_map(self.{{ SnakeCase .Name }}.len(), {{ PolyglotPrimitive $type }}, {{ PolyglotPrimitive .Value }})?; + for (k, v) in &self.{{ SnakeCase .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(&k)?; + e.{{ PolyglotPrimitiveEncode .Value }}(&v)?; + } + {{- else }} + e.encode_map(self.{{ SnakeCase .Name }}.len(), {{ PolyglotPrimitive $type }}, Kind::Any)?; + for (k, v) in &self.{{ SnakeCase .Name }} { + e.{{ PolyglotPrimitiveEncode $type }}(&k)?; + v.encode_self(e)?; + } + {{- end }} + + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/generator/rust/templates/templates.go b/extension/generator/rust/templates/templates.go new file mode 100644 index 00000000..fdfb581a --- /dev/null +++ b/extension/generator/rust/templates/templates.go @@ -0,0 +1,19 @@ +/* + Copyright 2023 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package templates + +import "embed" + +//go:embed * +var FS embed.FS diff --git a/extension/generator/rust/templates/types.rs.templ b/extension/generator/rust/templates/types.rs.templ new file mode 100644 index 00000000..edb108dc --- /dev/null +++ b/extension/generator/rust/templates/types.rs.templ @@ -0,0 +1,262 @@ +#![allow(dead_code)] +#![allow(unused_imports)] +#![allow(unused_variables)] +#![allow(unused_mut)] + +use std::io::Cursor; +use polyglot_rs::{DecodingError, Encoder, Decoder, Kind}; +use num_enum::TryFromPrimitive; +use std::convert::TryFrom; +use std::collections::HashMap; +use regex::Regex; + +pub trait Encode { + fn encode<'a> (a: Option<&Self>, b: &'a mut Cursor>) -> Result<&'a mut Cursor>, Box> where Self: Sized; +} + +trait EncodeSelf { + fn encode_self<'a, 'b> (&'b self, b: &'a mut Cursor>) -> Result<&'a mut Cursor>, Box>; +} + +pub trait Decode { + fn decode (b: &mut Cursor<&mut Vec>) -> Result, Box> where Self: Sized; +} + +{{- range .signature_schema.Enums }} + {{ template "rs_enums_definition" . }} +{{- end }} + +{{- range .signature_schema.Models -}} + {{- if .Description }} + // {{ .Name }}: {{ .Description }} + {{ end -}} + + #[derive(Clone, Debug, PartialEq)] + pub struct {{ .Name }} { + {{ template "rs_models_struct_reference" . }} + {{ template "rs_modelarrays_struct_reference" . }} + + {{ template "rs_primitives_struct_reference" Params "Entries" .Strings "Type" "string" }} + {{ template "rs_arrays_struct_reference" Params "Entries" .StringArrays "Type" "string" }} + {{ template "rs_maps_struct_reference" Params "Entries" .StringMaps "Type" "string" }} + + {{ template "rs_primitives_struct_reference" Params "Entries" .Int32s "Type" "int32" }} + {{ template "rs_arrays_struct_reference" Params "Entries" .Int32Arrays "Type" "int32" }} + {{ template "rs_maps_struct_reference" Params "Entries" .Int32Maps "Type" "int32" }} + + + {{ template "rs_primitives_struct_reference" Params "Entries" .Int64s "Type" "int64" }} + {{ template "rs_arrays_struct_reference" Params "Entries" .Int64Arrays "Type" "int64" }} + {{ template "rs_maps_struct_reference" Params "Entries" .Int64Maps "Type" "int64" }} + + {{ template "rs_primitives_struct_reference" Params "Entries" .Uint32s "Type" "uint32" }} + {{ template "rs_arrays_struct_reference" Params "Entries" .Uint32Arrays "Type" "uint32" }} + {{ template "rs_maps_struct_reference" Params "Entries" .Uint32Maps "Type" "uint32" }} + + {{ template "rs_primitives_struct_reference" Params "Entries" .Uint64s "Type" "uint64" }} + {{ template "rs_arrays_struct_reference" Params "Entries" .Uint64Arrays "Type" "uint64" }} + {{ template "rs_maps_struct_reference" Params "Entries" .Uint64Maps "Type" "uint64" }} + + {{ template "rs_primitives_struct_reference" Params "Entries" .Float32s "Type" "float32" }} + {{ template "rs_arrays_struct_reference" Params "Entries" .Float32Arrays "Type" "float32" }} + + {{ template "rs_primitives_struct_reference" Params "Entries" .Float64s "Type" "float64" }} + {{ template "rs_arrays_struct_reference" Params "Entries" .Float64Arrays "Type" "float64" }} + + {{ template "rs_enums_struct_reference" . }} + {{ template "rs_enumarrays_struct_reference" . }} + {{ template "rs_enummaps_struct_reference" . }} + + {{ template "rs_primitives_struct_reference" Params "Entries" .Bytes "Type" "bytes" }} + {{ template "rs_arrays_struct_reference" Params "Entries" .BytesArrays "Type" "bytes" }} + + {{ template "rs_primitives_struct_reference" Params "Entries" .Bools "Type" "bool" }} + {{ template "rs_arrays_struct_reference" Params "Entries" .BoolArrays "Type" "bool" }} + } + + impl {{ .Name }} { + pub fn new () -> Self { + Self { + {{ template "rs_models_new_struct_reference" . }} + {{ template "rs_modelarrays_new_struct_reference" . }} + + {{ template "rs_strings_new_struct_reference" Params "Entries" .Strings }} + {{ template "rs_arrays_new_struct_reference" Params "Entries" .StringArrays "Type" "string" }} + {{ template "rs_maps_new_struct_reference" Params "Entries" .StringMaps "Type" "string" }} + + {{ template "rs_primitives_new_struct_reference" Params "Entries" .Int32s }} + {{ template "rs_arrays_new_struct_reference" Params "Entries" .Int32Arrays "Type" "int32" }} + {{ template "rs_maps_new_struct_reference" Params "Entries" .Int32Maps "Type" "int32" }} + + {{ template "rs_primitives_new_struct_reference" Params "Entries" .Int64s }} + {{ template "rs_arrays_new_struct_reference" Params "Entries" .Int64Arrays "Type" "int64" }} + {{ template "rs_maps_new_struct_reference" Params "Entries" .Int64Maps "Type" "int64" }} + + {{ template "rs_primitives_new_struct_reference" Params "Entries" .Uint32s }} + {{ template "rs_arrays_new_struct_reference" Params "Entries" .Uint32Arrays "Type" "uint32" }} + {{ template "rs_maps_new_struct_reference" Params "Entries" .Uint32Maps "Type" "uint32" }} + + {{ template "rs_primitives_new_struct_reference" Params "Entries" .Uint64s }} + {{ template "rs_arrays_new_struct_reference" Params "Entries" .Uint64Arrays "Type" "uint64" }} + {{ template "rs_maps_new_struct_reference" Params "Entries" .Uint64Maps "Type" "uint64" }} + + {{ template "rs_primitives_new_struct_reference" Params "Entries" .Float32s }} + {{ template "rs_arrays_new_struct_reference" Params "Entries" .Float32Arrays "Type" "float32" }} + + {{ template "rs_primitives_new_struct_reference" Params "Entries" .Float64s }} + {{ template "rs_arrays_new_struct_reference" Params "Entries" .Float64Arrays "Type" "float64" }} + + {{ template "rs_enums_new_struct_reference" . }} + {{ template "rs_enumarrays_new_struct_reference" . }} + {{ template "rs_enummaps_new_struct_reference" . }} + + {{ template "rs_bytes_new_struct_reference" Params "Entries" .Bytes }} + {{ template "rs_arrays_new_struct_reference" Params "Entries" .BytesArrays "Type" "bytes" }} + + {{ template "rs_primitives_new_struct_reference" Params "Entries" .Bools }} + {{ template "rs_arrays_new_struct_reference" Params "Entries" .BoolArrays "Type" "bool" }} + } + } + + + {{ template "rs_strings_accessor" Params "Model" . "Entries" .Strings "Type" "string" }} + {{ template "rs_numbers_accessor" Params "Model" . "Entries" .Int32s "Type" "int32" }} + {{ template "rs_numbers_accessor" Params "Model" . "Entries" .Int64s "Type" "int64" }} + {{ template "rs_numbers_accessor" Params "Model" . "Entries" .Uint32s "Type" "uint32" }} + {{ template "rs_numbers_accessor" Params "Model" . "Entries" .Uint64s "Type" "uint64" }} + {{ template "rs_numbers_accessor" Params "Model" . "Entries" .Float32s "Type" "float32" }} + {{ template "rs_numbers_accessor" Params "Model" . "Entries" .Float64s "Type" "float32" }} + } + + impl Encode for {{ .Name }} { + fn encode<'a> (a: Option<&{{ .Name }}>, e: &'a mut Cursor>) -> Result<&'a mut Cursor>, Box> { + a.encode_self(e) + } + } + + impl EncodeSelf for {{ .Name }} { + fn encode_self<'a, 'b> (&'b self, e: &'a mut Cursor>) -> Result<&'a mut Cursor>, Box> { + {{ template "rs_models_encode" . }} + {{ template "rs_modelarrays_encode" . }} + + {{ template "rs_ref_encode" Params "Entries" .Strings "Type" "string" }} + {{ template "rs_refarrays_encode" Params "Entries" .StringArrays "Type" "string" }} + {{ template "rs_refmaps_encode" Params "Entries" .StringMaps "Type" "string" }} + + {{ template "rs_primitives_encode" Params "Entries" .Int32s "Type" "int32" }} + {{ template "rs_arrays_encode" Params "Entries" .Int32Arrays "Type" "int32" }} + {{ template "rs_maps_encode" Params "Entries" .Int32Maps "Type" "int32" }} + + {{ template "rs_primitives_encode" Params "Entries" .Int64s "Type" "int64" }} + {{ template "rs_arrays_encode" Params "Entries" .Int64Arrays "Type" "int64" }} + {{ template "rs_maps_encode" Params "Entries" .Int64Maps "Type" "int64" }} + + {{ template "rs_primitives_encode" Params "Entries" .Uint32s "Type" "uint32" }} + {{ template "rs_arrays_encode" Params "Entries" .Uint32Arrays "Type" "uint32" }} + {{ template "rs_maps_encode" Params "Entries" .Uint32Maps "Type" "uint32" }} + + {{ template "rs_primitives_encode" Params "Entries" .Uint64s "Type" "uint64" }} + {{ template "rs_arrays_encode" Params "Entries" .Uint64Arrays "Type" "uint64" }} + {{ template "rs_maps_encode" Params "Entries" .Uint64Maps "Type" "uint64" }} + + {{ template "rs_primitives_encode" Params "Entries" .Float32s "Type" "float32" }} + {{ template "rs_arrays_encode" Params "Entries" .Float32Arrays "Type" "float32" }} + + {{ template "rs_primitives_encode" Params "Entries" .Float64s "Type" "float64" }} + {{ template "rs_arrays_encode" Params "Entries" .Float64Arrays "Type" "float64" }} + + {{ template "rs_enums_encode" . }} + {{ template "rs_enumarrays_encode" . }} + {{ template "rs_enummaps_encode" . }} + + {{ template "rs_ref_encode" Params "Entries" .Bytes "Type" "bytes" }} + {{ template "rs_refarrays_encode" Params "Entries" .BytesArrays "Type" "bytes" }} + + {{ template "rs_primitives_encode" Params "Entries" .Bools "Type" "bool" }} + {{ template "rs_arrays_encode" Params "Entries" .BoolArrays "Type" "bool" }} + + Ok(e) + } + } + + impl EncodeSelf for Option<&{{ .Name }}> { + fn encode_self<'a, 'b> (&'b self, e: &'a mut Cursor>) -> Result<&'a mut Cursor>, Box> { + if let Some(x) = self { + x.encode_self(e)?; + } else { + e.encode_none()?; + } + Ok(e) + } + } + + impl EncodeSelf for Option<{{ .Name }}> { + fn encode_self<'a, 'b> (&'b self, e: &'a mut Cursor>) -> Result<&'a mut Cursor>, Box> { + if let Some(x) = self { + x.encode_self(e)?; + } else { + e.encode_none()?; + } + Ok(e) + } + } + + impl Decode for {{ .Name }} { + fn decode (d: &mut Cursor<&mut Vec>) -> Result, Box> { + if d.decode_none() { + return Ok(None); + } + + if let Ok(error) = d.decode_error() { + return Err(error); + } + + let mut x = {{ .Name }}::new(); + + {{ template "rs_models_decode" . }} + {{ template "rs_modelarrays_decode" . }} + + {{ template "rs_primitives_decode" Params "Entries" .Strings "Type" "string" }} + {{ template "rs_arrays_decode" Params "Entries" .StringArrays "Type" "string" }} + {{ template "rs_maps_decode" Params "Entries" .StringMaps "Type" "string" }} + + {{ template "rs_primitives_decode" Params "Entries" .Int32s "Type" "int32" }} + {{ template "rs_arrays_decode" Params "Entries" .Int32Arrays "Type" "int32" }} + {{ template "rs_maps_decode" Params "Entries" .Int32Maps "Type" "int32" }} + + {{ template "rs_primitives_decode" Params "Entries" .Int64s "Type" "int64" }} + {{ template "rs_arrays_decode" Params "Entries" .Int64Arrays "Type" "int64" }} + {{ template "rs_maps_decode" Params "Entries" .Int64Maps "Type" "int64" }} + + {{ template "rs_primitives_decode" Params "Entries" .Uint32s "Type" "uint32" }} + {{ template "rs_arrays_decode" Params "Entries" .Uint32Arrays "Type" "uint32" }} + {{ template "rs_maps_decode" Params "Entries" .Uint32Maps "Type" "uint32" }} + + {{ template "rs_primitives_decode" Params "Entries" .Uint64s "Type" "uint64" }} + {{ template "rs_arrays_decode" Params "Entries" .Uint64Arrays "Type" "uint64" }} + {{ template "rs_maps_decode" Params "Entries" .Uint64Maps "Type" "uint64" }} + + {{ template "rs_primitives_decode" Params "Entries" .Float32s "Type" "float32" }} + {{ template "rs_arrays_decode" Params "Entries" .Float32Arrays "Type" "float32" }} + + {{ template "rs_primitives_decode" Params "Entries" .Float64s "Type" "float64" }} + {{ template "rs_arrays_decode" Params "Entries" .Float64Arrays "Type" "float64" }} + + {{ template "rs_enums_decode" . }} + {{ template "rs_enumarrays_decode" . }} + {{ template "rs_enummaps_decode" . }} + + {{ template "rs_primitives_decode" Params "Entries" .Bytes "Type" "bytes" }} + {{ template "rs_arrays_decode" Params "Entries" .BytesArrays "Type" "bytes" }} + + {{ template "rs_primitives_decode" Params "Entries" .Bools "Type" "bool" }} + {{ template "rs_arrays_decode" Params "Entries" .BoolArrays "Type" "bool" }} + + Ok(Some(x)) + } + } + + {{ template "rs_models_accessor" . }} + {{ template "rs_modelarrays_accessor" . }} + {{ template "rs_enums_accessor" . }} +{{ end -}} \ No newline at end of file diff --git a/extension/generator/rust/templates/validators.rs.templ b/extension/generator/rust/templates/validators.rs.templ new file mode 100644 index 00000000..5308a27d --- /dev/null +++ b/extension/generator/rust/templates/validators.rs.templ @@ -0,0 +1,53 @@ +{{ define "rs_numbers_limit_validator" }} +{{- if . }} + {{- if and .Maximum .Minimum }} + if v > {{ .Maximum }} || v < {{ .Minimum }} { + return Err(Box::::from("value must be between { .Minimum }} and {{ .Maximum }}")); + } + {{- else if .Minimum }} + if v < {{ .Minimum }} { + return Err(Box::::from("value must be greater than or equal to {{ .Minimum }}")); + } + {{- else if .Maximum }} + if v > {{ .Maximum }} { + return Err(Box::::from("value must be less than or equal to {{ .Maximum }}")); + } + {{- end }} +{{- end }} +{{ end }} + +{{ define "rs_regex_validator" }} + {{- if . }} + if !Regex::new("^[a-zA-Z0-9]*$")?.is_match(v.as_str()) { + return Err(Box::::from("value must match {{ .Expression }}")); + } + {{- end }} +{{ end }} + +{{ define "rs_length_validator" }} + {{- if . }} + {{- if and .Maximum .Minimum }} + if v.len() > {{ .Maximum }} || v.len() < {{ .Minimum }} { + return Err(Box::::from("value must be between { .Minimum }} and {{ .Maximum }}")); + } + {{- else if .Minimum }} + if v.len() < {{ .Minimum }} { + return Err(Box::::from("value must be greater than or equal to {{ .Minimum }}")); + } + {{- else if .Maximum }} + if v.len() > {{ .Maximum }} { + return Err(Box::::from("value must be less than or equal to {{ .Maximum }}")); + } + {{- end }} + {{- end }} +{{ end }} + +{{ define "rs_case_modifier" }} + {{- if . }} + {{- if eq .Kind "upper" }} + v = v.to_uppercase(); + {{- else if eq .Kind "lower" }} + v = v.to_lowercase(); + {{- end }} + {{- end }} +{{ end }} \ No newline at end of file diff --git a/extension/info.go b/extension/info.go new file mode 100644 index 00000000..a91c4817 --- /dev/null +++ b/extension/info.go @@ -0,0 +1,7 @@ +package extension + +type ExtensionInfo struct { + Name string + Path string + Version string +} diff --git a/extension/interface.go b/extension/interface.go new file mode 100644 index 00000000..f9b2f0df --- /dev/null +++ b/extension/interface.go @@ -0,0 +1,23 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package extension + +type InterfaceSchema struct { + Name string `hcl:"name,label"` + Description string `hcl:"description,optional"` + Functions []*FunctionSchema `hcl:"function,block"` +} diff --git a/extension/schema.go b/extension/schema.go new file mode 100644 index 00000000..2c1f3a23 --- /dev/null +++ b/extension/schema.go @@ -0,0 +1,611 @@ +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package extension + +import ( + "crypto/sha256" + "errors" + "fmt" + "os" + "regexp" + + "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/gohcl" + "github.com/hashicorp/hcl/v2/hclsyntax" + "github.com/hashicorp/hcl/v2/hclwrite" + "github.com/loopholelabs/scale/signature" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +const ( + V1AlphaVersion = "v1alpha" +) + +var ( + ErrInvalidName = errors.New("invalid name") + ErrInvalidFunctionName = errors.New("invalid function name") + ErrInvalidTag = errors.New("invalid tag") + ErrNoInstanceId = errors.New("Extension has no int32 InstanceId defined") +) + +var ( + ValidLabel = regexp.MustCompile(`^[A-Za-z0-9]*$`) + InvalidString = regexp.MustCompile(`[^A-Za-z0-9-.]`) +) + +var ( + TitleCaser = cases.Title(language.Und, cases.NoLower) +) + +type Schema struct { + Version string `hcl:"version,attr"` + Name string `hcl:"name,attr"` + Tag string `hcl:"tag,attr"` + Interfaces []*InterfaceSchema `hcl:"interface,block"` + Functions []*FunctionSchema `hcl:"function,block"` + Enums []*signature.EnumSchema `hcl:"enum,block"` + Models []*signature.ModelSchema `hcl:"model,block"` + hasLimitValidator bool + hasLengthValidator bool + hasRegexValidator bool + hasCaseModifier bool +} + +func ReadSchema(path string) (*Schema, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read schema file: %w", err) + } + + s := new(Schema) + return s, s.Decode(data) +} + +func (s *Schema) Decode(data []byte) error { + file, diag := hclsyntax.ParseConfig(data, "", hcl.Pos{Line: 1, Column: 1}) + if diag.HasErrors() { + return diag.Errs()[0] + } + + diag = gohcl.DecodeBody(file.Body, nil, s) + if diag.HasErrors() { + return diag.Errs()[0] + } + + return nil +} + +func (s *Schema) Encode() ([]byte, error) { + f := hclwrite.NewEmptyFile() + gohcl.EncodeIntoBody(s, f.Body()) + return f.Bytes(), nil +} + +func (s *Schema) Validate() error { + switch s.Version { + case V1AlphaVersion: + if !ValidLabel.MatchString(s.Name) { + return ErrInvalidName + } + + if InvalidString.MatchString(s.Tag) { + return ErrInvalidTag + } + + // Transform all model names and references to TitleCase (e.g. "myModel" -> "MyModel") + for _, model := range s.Models { + model.Normalize() + } + + // Transform all model names and references to TitleCase (e.g. "myModel" -> "MyModel") + for _, enum := range s.Enums { + enum.Normalize() + } + + // Validate all models + knownModels := make(map[string]struct{}) + for _, model := range s.Models { + err := model.Validate(knownModels, s.Enums) + if err != nil { + return err + } + } + + // Validate all enums + knownEnums := make(map[string]struct{}) + for _, enum := range s.Enums { + err := enum.Validate(knownEnums) + if err != nil { + return err + } + } + + // Ensure all model and enum references are valid + for _, model := range s.Models { + for _, modelReference := range model.Models { + if _, ok := knownModels[modelReference.Reference]; !ok { + return fmt.Errorf("unknown %s.%s.reference: %s", model.Name, modelReference.Name, modelReference.Reference) + } + } + + for _, modelReferenceArray := range model.ModelArrays { + if _, ok := knownModels[modelReferenceArray.Reference]; !ok { + return fmt.Errorf("unknown %s.%s.reference: %s", model.Name, modelReferenceArray.Name, modelReferenceArray.Reference) + } + } + + for _, str := range model.Strings { + if str.LengthValidator != nil { + s.hasLengthValidator = true + } + if str.RegexValidator != nil { + s.hasRegexValidator = true + } + if str.CaseModifier != nil { + s.hasCaseModifier = true + } + } + + for _, strMap := range model.StringMaps { + if !signature.ValidPrimitiveType(strMap.Value) { + if _, ok := knownModels[strMap.Value]; !ok { + return fmt.Errorf("unknown %s.%s.value: %s", model.Name, strMap.Name, strMap.Value) + } + } + } + + for _, i32 := range model.Int32s { + if i32.LimitValidator != nil { + s.hasLimitValidator = true + } + } + + for _, i32Map := range model.Int32Maps { + if !signature.ValidPrimitiveType(i32Map.Value) { + if _, ok := knownModels[i32Map.Value]; !ok { + return fmt.Errorf("unknown %s.%s.value: %s", model.Name, i32Map.Name, i32Map.Value) + } + } + } + + for _, i64 := range model.Int64s { + if i64.LimitValidator != nil { + s.hasLimitValidator = true + } + } + + for _, i64Map := range model.Int64Maps { + if !signature.ValidPrimitiveType(i64Map.Value) { + if _, ok := knownModels[i64Map.Value]; !ok { + return fmt.Errorf("unknown %s.%s.value: %s", model.Name, i64Map.Name, i64Map.Value) + } + } + } + + for _, u32 := range model.Uint32s { + if u32.LimitValidator != nil { + s.hasLimitValidator = true + } + } + + for _, u32Map := range model.Uint32Maps { + if !signature.ValidPrimitiveType(u32Map.Value) { + if _, ok := knownModels[u32Map.Value]; !ok { + return fmt.Errorf("unknown %s.%s.value: %s", model.Name, u32Map.Name, u32Map.Value) + } + } + } + + for _, u64 := range model.Uint64s { + if u64.LimitValidator != nil { + s.hasLimitValidator = true + } + } + + for _, u64Map := range model.Uint64Maps { + if !signature.ValidPrimitiveType(u64Map.Value) { + if _, ok := knownModels[u64Map.Value]; !ok { + return fmt.Errorf("unknown %s.%s.value: %s", model.Name, u64Map.Name, u64Map.Value) + } + } + } + + for _, f32 := range model.Float32s { + if f32.LimitValidator != nil { + s.hasLimitValidator = true + } + } + + for _, f64 := range model.Float64s { + if f64.LimitValidator != nil { + s.hasLimitValidator = true + } + } + + for _, enumReference := range model.Enums { + if _, ok := knownEnums[enumReference.Reference]; !ok { + return fmt.Errorf("unknown %s.%s.reference: %s", model.Name, enumReference.Name, enumReference.Reference) + } + } + + for _, enumReferenceArray := range model.EnumArrays { + if _, ok := knownEnums[enumReferenceArray.Reference]; !ok { + return fmt.Errorf("unknown %s.%s.reference: %s", model.Name, enumReferenceArray.Name, enumReferenceArray.Reference) + } + } + + for _, enumReferenceMap := range model.EnumMaps { + if _, ok := knownEnums[enumReferenceMap.Reference]; !ok { + return fmt.Errorf("unknown %s.%s.reference: %s", model.Name, enumReferenceMap.Name, enumReferenceMap.Reference) + } + + if !signature.ValidPrimitiveType(enumReferenceMap.Value) { + if _, ok := knownModels[enumReferenceMap.Value]; !ok { + return fmt.Errorf("unknown %s.%s.value: %s", model.Name, enumReferenceMap.Name, enumReferenceMap.Value) + } + } + } + } + + // Map of interfaces, and check for name collisions. + knownInterfaces := make(map[string]struct{}) + for _, inter := range s.Interfaces { + _, dupe := knownModels[inter.Name] + if dupe { + return fmt.Errorf("interface name collides with a model %s", inter.Name) + } + _, dupe = knownInterfaces[inter.Name] + if dupe { + return fmt.Errorf("interface name collides with an interface %s", inter.Name) + } + knownInterfaces[inter.Name] = struct{}{} + } + + for _, inter := range s.Interfaces { + for _, f := range inter.Functions { + // Make sure the function name is ok + if !ValidLabel.MatchString(f.Name) { + return ErrInvalidFunctionName + } + + // Make sure the params exist as model. + if f.Params != "" { + f.Params = TitleCaser.String(f.Params) + if _, ok := knownModels[f.Params]; !ok { + return fmt.Errorf("unknown params in function %s: %s", f.Name, f.Params) + } + } + + // Return can either be a model or interface + if f.Return != "" { + f.Return = TitleCaser.String(f.Return) + _, foundModel := knownModels[f.Return] + _, foundInterface := knownInterfaces[f.Return] + if !foundModel && !foundInterface { + return fmt.Errorf("unknown return in function %s: %s", f.Name, f.Return) + } + } + } + } + + // Check any global functions + for _, f := range s.Functions { + // Make sure the function name is ok + if !ValidLabel.MatchString(f.Name) { + return ErrInvalidFunctionName + } + + // Make sure the params exist as model. + if f.Params != "" { + f.Params = TitleCaser.String(f.Params) + if _, ok := knownModels[f.Params]; !ok { + return fmt.Errorf("unknown params in function %s: %s", f.Name, f.Params) + } + } + + // Return can either be a model or interface + if f.Return != "" { + f.Return = TitleCaser.String(f.Return) + _, foundModel := knownModels[f.Return] + _, foundInterface := knownInterfaces[f.Return] + if !foundModel && !foundInterface { + return fmt.Errorf("unknown return in function %s: %s", f.Name, f.Return) + } + } + } + + return nil + default: + return fmt.Errorf("unknown schema version: %s", s.Version) + } + +} + +// Hash returns the SHA256 hash of the schema +func (s *Schema) Hash() ([]byte, error) { + d, err := s.Encode() + if err != nil { + return nil, err + } + + h := sha256.New() + if _, err = h.Write(d); err != nil { + return nil, err + } + return h.Sum(nil), nil +} + +// Clone returns a deep copy of the schema +func (s *Schema) Clone() (*Schema, error) { + clone := new(Schema) + encoded, err := s.Encode() + if err != nil { + return nil, err + } + if err = clone.Decode(encoded); err != nil { + return nil, err + } + return clone, nil +} + +// CloneWithDisabledAccessorsValidatorsAndModifiers returns a clone of the +// schema with all accessors, validators, and modifiers disabled +func (s *Schema) CloneWithDisabledAccessorsValidatorsAndModifiers() (*Schema, error) { + clone, err := s.Clone() + if err != nil { + return nil, err + } + clone.hasCaseModifier = false + clone.hasLimitValidator = false + clone.hasRegexValidator = false + clone.hasLengthValidator = false + for _, model := range clone.Models { + for _, modelReference := range model.Models { + modelReference.Accessor = false + } + + for _, modelReferenceArray := range model.ModelArrays { + modelReferenceArray.Accessor = + false + } + + for _, str := range model.Strings { + var accessorValue bool + str.Accessor = &accessorValue + str.CaseModifier = nil + str.LengthValidator = nil + str.RegexValidator = nil + } + + for _, strArray := range model.StringArrays { + var accessorValue bool + strArray.Accessor = &accessorValue + } + + for _, strMap := range model.StringMaps { + var accessorValue bool + strMap.Accessor = &accessorValue + } + + for _, i32 := range model.Int32s { + var accessorValue bool + i32.Accessor = &accessorValue + i32.LimitValidator = nil + } + + for _, i32Array := range model.Int32Arrays { + var accessorValue bool + i32Array.Accessor = &accessorValue + } + + for _, i32Map := range model.Int32Maps { + var accessorValue bool + i32Map.Accessor = &accessorValue + } + + for _, i64 := range model.Int64s { + var accessorValue bool + i64.Accessor = &accessorValue + i64.LimitValidator = nil + } + + for _, i64Array := range model.Int64Arrays { + var accessorValue bool + i64Array.Accessor = &accessorValue + } + + for _, i64Map := range model.Int64Maps { + var accessorValue bool + i64Map.Accessor = &accessorValue + } + + for _, u32 := range model.Uint32s { + var accessorValue bool + u32.Accessor = &accessorValue + u32.LimitValidator = nil + } + + for _, u32Array := range model.Uint32Arrays { + var accessorValue bool + u32Array.Accessor = &accessorValue + } + + for _, u32Map := range model.Uint32Maps { + var accessorValue bool + u32Map.Accessor = &accessorValue + } + + for _, u64 := range model.Uint64s { + var accessorValue bool + u64.Accessor = &accessorValue + u64.LimitValidator = nil + } + + for _, u64Array := range model.Uint64Arrays { + var accessorValue bool + u64Array.Accessor = &accessorValue + } + + for _, u64Map := range model.Uint64Maps { + var accessorValue bool + u64Map.Accessor = &accessorValue + } + + for _, f32 := range model.Float32s { + var accessorValue bool + f32.Accessor = &accessorValue + f32.LimitValidator = nil + } + + for _, f32Array := range model.Float32Arrays { + var accessorValue bool + f32Array.Accessor = &accessorValue + } + + for _, f64 := range model.Float64s { + var accessorValue bool + f64.Accessor = &accessorValue + f64.LimitValidator = nil + } + + for _, f64Array := range model.Float64Arrays { + var accessorValue bool + f64Array.Accessor = &accessorValue + } + + for _, boolean := range model.Bools { + boolean.Accessor = false + } + + for _, booleanArray := range model.BoolArrays { + booleanArray.Accessor = false + } + + for _, b := range model.Bytes { + b.Accessor = false + } + + for _, bytesArray := range model.BytesArrays { + bytesArray.Accessor = false + } + + for _, enumReference := range model.Enums { + enumReference.Accessor = false + } + + for _, enumReferenceArray := range model.EnumArrays { + enumReferenceArray.Accessor = false + } + + for _, enumReferenceMap := range model.EnumMaps { + enumReferenceMap.Accessor = false + } + } + + return clone, clone.validateAndNormalize() +} + +// validateAndNormalize validates the Schema and normalizes it +// +// Note: This function modifies the Schema in-place +func (s *Schema) validateAndNormalize() error { + + // TODO... + + return nil +} + +func (s *Schema) HasLimitValidator() bool { + return s.hasLimitValidator +} + +func (s *Schema) HasLengthValidator() bool { + return s.hasLengthValidator +} + +func (s *Schema) HasRegexValidator() bool { + return s.hasRegexValidator +} + +func (s *Schema) HasCaseModifier() bool { + return s.hasCaseModifier +} + +func ValidPrimitiveType(t string) bool { + switch t { + case "string", "int32", "int64", "uint32", "uint64", "float32", "float64", "bool", "bytes": + return true + default: + return false + } +} + +const MasterTestingSchema = ` +version = "v1alpha" +name = "HttpFetch" +tag = "alpha" + +function New { + params = "HttpConfig" + return = "HttpConnector" +} + +model HttpConfig { + int32 timeout { + default = 60 + accessor = false + } +} + +model HttpResponse { + string_map Headers { + value = "StringList" + accessor = false + } + int32 StatusCode { + default = 0 + accessor = false + } + bytes Body { + initial_size = 0 + accessor = false + } +} + +model StringList { + string_array Values { + initial_size = 0 + accessor = false + } +} + +model ConnectionDetails { + string url { + default = "https://google.com" + accessor = false + } +} + +interface HttpConnector { + function Fetch { + params = "ConnectionDetails" + return = "HttpResponse" + } +} + +` diff --git a/extension/schema_test.go b/extension/schema_test.go new file mode 100644 index 00000000..e4359e02 --- /dev/null +++ b/extension/schema_test.go @@ -0,0 +1,48 @@ +//go:build !integration + +/* + Copyright 2023 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package extension + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSchema(t *testing.T) { + s := new(Schema) + err := s.Decode([]byte(MasterTestingSchema)) + require.NoError(t, err) + + require.NoError(t, s.Validate()) + + assert.Equal(t, V1AlphaVersion, s.Version) + assert.Equal(t, "HttpFetch", s.Name) + assert.Equal(t, "alpha", s.Tag) + + // Make sure there's a global function defined... + assert.Equal(t, 1, len(s.Functions)) + + assert.Equal(t, "New", s.Functions[0].Name) + assert.Equal(t, "HttpConfig", s.Functions[0].Params) + assert.Equal(t, "HttpConnector", s.Functions[0].Return) + + // NB You could test the models, but that should be done in signature already. + +} diff --git a/go.mod b/go.mod index 919c6588..85839b13 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/kr/pretty v0.3.1 // indirect + github.com/loopholelabs/scale-extension-interfaces v0.0.0-20230920094333-3a483b301bf4 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/zclconf/go-cty v1.13.2 // indirect diff --git a/go.sum b/go.sum index 5924c588..c2f9128d 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4= github.com/loopholelabs/polyglot v1.1.3 h1:WUTcSZ2TQ1lv7CZ4I9nHFBUjf0hKJN+Yfz1rZZJuTP0= github.com/loopholelabs/polyglot v1.1.3/go.mod h1:EA88BEkIluKHAWxhyOV88xXz68YkRdo9IzZ+1dj+7Ao= +github.com/loopholelabs/scale-extension-interfaces v0.0.0-20230920094333-3a483b301bf4 h1:leEQ1uJgTcyKdJojndhZRBAbOxcUz3u5Ol05Kz6tedQ= +github.com/loopholelabs/scale-extension-interfaces v0.0.0-20230920094333-3a483b301bf4/go.mod h1:/qjvg9RglZaRhw3cE+dj6AaZRYHCOR+ohuZX8MSxD8E= github.com/loopholelabs/scale-signature-interfaces v0.1.7 h1:aOJJZpCKn/Q5Q0Gj+/Q6c7/iABEbojjbCzIqw7Mxyi0= github.com/loopholelabs/scale-signature-interfaces v0.1.7/go.mod h1:3XLMjJjBf5lYxMtNKk+2XAWye4UyrkvUBJ9L6x2QCAk= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= diff --git a/instance.go b/instance.go index 11083116..778c6aa3 100644 --- a/instance.go +++ b/instance.go @@ -80,6 +80,8 @@ func newInstance[T interfaces.Signature](ctx context.Context, runtime *Scale[T], } func (i *Instance[T]) Run(ctx context.Context, signature T) error { + i.runtime.resetExtensions() + m, err := i.head.getModule(signature) if err != nil { return fmt.Errorf("failed to get module for function '%s': %w", i.head.template.identifier, err) diff --git a/scale.go b/scale.go index 40a3c977..187844be 100644 --- a/scale.go +++ b/scale.go @@ -28,6 +28,8 @@ import ( "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" + + extension "github.com/loopholelabs/scale-extension-interfaces" ) // Next is the next function in the middleware chain. It's meant to be implemented @@ -68,6 +70,13 @@ func (r *Scale[T]) Instance(next ...Next[T]) (*Instance[T], error) { return newInstance(r.config.context, r, next...) } +// Reset any extensions between executions. +func (r *Scale[T]) resetExtensions() { + for _, ext := range r.config.extensions { + ext.Reset() + } +} + func (r *Scale[T]) init() error { err := r.config.validate() if err != nil { @@ -82,7 +91,30 @@ func (r *Scale[T]) init() error { r.moduleConfig = r.moduleConfig.WithStderr(r.config.stderr) } - envHostModuleBuilder := r.runtime.NewHostModuleBuilder("env"). + envModule := r.runtime.NewHostModuleBuilder("env") + + // Install any extensions... + for _, ext := range r.config.extensions { + fns := ext.Init() + for name, fn := range fns { + wfn := func(n string, f extension.InstallableFunc) func(context.Context, api.Module, []uint64) { + return func(ctx context.Context, mod api.Module, params []uint64) { + mem := mod.Memory() + resize := func(name string, size uint64) (uint64, error) { + w, err := mod.ExportedFunction(name).Call(context.Background(), size) + return w[0], err + } + f(mem, resize, params) + } + }(name, fn) + + envModule.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc(wfn), []api.ValueType{api.ValueTypeI64, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI64}). + WithParameterNames("instance", "pointer", "length").Export(name) + } + } + + envHostModuleBuilder := envModule. NewFunctionBuilder(). WithGoModuleFunction(api.GoModuleFunc(r.next), []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{}). WithParameterNames("pointer", "length").Export("next") diff --git a/scalefile/scalefile.go b/scalefile/scalefile.go index 7bc1f664..65c0cfa6 100644 --- a/scalefile/scalefile.go +++ b/scalefile/scalefile.go @@ -47,16 +47,23 @@ type SignatureSchema struct { Tag string `hcl:"tag,attr"` } +type ExtensionSchema struct { + Organization string `hcl:"organization,optional"` + Name string `hcl:"name,attr"` + Tag string `hcl:"tag,attr"` +} + type Schema struct { - Version string `hcl:"version,attr"` - Name string `hcl:"name,attr"` - Tag string `hcl:"tag,attr"` - Language string `hcl:"language,attr"` - Signature SignatureSchema `hcl:"signature,block"` - Stateless bool `hcl:"stateless,optional"` - Function string `hcl:"function,attr"` - Initialize string `hcl:"initialize,attr"` - Description string `hcl:"description,optional"` + Version string `hcl:"version,attr"` + Name string `hcl:"name,attr"` + Tag string `hcl:"tag,attr"` + Language string `hcl:"language,attr"` + Signature SignatureSchema `hcl:"signature,block"` + Stateless bool `hcl:"stateless,optional"` + Function string `hcl:"function,attr"` + Initialize string `hcl:"initialize,attr"` + Description string `hcl:"description,optional"` + Extensions []ExtensionSchema `hcl:"extension,block"` } func ReadSchema(path string) (*Schema, error) { diff --git a/storage/extension.go b/storage/extension.go new file mode 100644 index 00000000..4db6201f --- /dev/null +++ b/storage/extension.go @@ -0,0 +1,355 @@ +/* + Copyright 2022 Loophole Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +// Package storage is used to store and retrieve built Scale Functions +package storage + +import ( + "encoding/hex" + "fmt" + "os" + "path" + "path/filepath" + + "github.com/loopholelabs/scale/extension" + "github.com/loopholelabs/scale/extension/generator" + "github.com/loopholelabs/scale/scalefunc" +) + +const ( + ExtensionDirectory = "extensions" +) + +var ( + DefaultExtension *ExtensionStorage +) + +type Extension struct { + Name string + Tag string + Schema *extension.Schema + Hash string + Organization string +} + +type ExtensionStorage struct { + Directory string +} + +func init() { + homeDir, err := os.UserHomeDir() + if err != nil { + panic(err) + } + DefaultExtension, err = NewExtension(path.Join(homeDir, DefaultDirectory, ExtensionDirectory)) + if err != nil { + panic(err) + } +} + +func NewExtension(baseDirectory string) (*ExtensionStorage, error) { + err := os.MkdirAll(baseDirectory, 0755) + if err != nil { + if !os.IsExist(err) { + return nil, err + } + } + + return &ExtensionStorage{ + Directory: baseDirectory, + }, nil +} + +// Get returns the Scale Extension with the given name, tag, and organization. +// The hash parameter is optional and can be used to check for a specific hash. +func (s *ExtensionStorage) Get(name string, tag string, org string, hash string) (*Extension, error) { + if name == "" || !scalefunc.ValidString(name) { + return nil, ErrInvalidName + } + + if tag == "" || !scalefunc.ValidString(tag) { + return nil, ErrInvalidTag + } + + if org == "" || !scalefunc.ValidString(org) { + return nil, ErrInvalidOrganization + } + + if hash != "" { + f := s.extensionName(name, tag, org, hash) + p := s.fullPath(f) + + stat, err := os.Stat(p) + if err != nil { + return nil, err + } + + if !stat.IsDir() { + return nil, fmt.Errorf("found extension is a file not a directory %s/%s:%s", org, name, tag) + } + + sig, err := extension.ReadSchema(path.Join(p, "extension")) + if err != nil { + return nil, err + } + + return &Extension{ + Name: name, + Tag: tag, + Schema: sig, + Hash: hash, + Organization: org, + }, nil + } + + f := s.extensionSearch(name, tag, org) + p := s.fullPath(f) + + matches, err := filepath.Glob(p) + if err != nil { + return nil, err + } + + if len(matches) == 0 { + return nil, nil + } + + if len(matches) > 1 { + return nil, fmt.Errorf("multiple matches found for %s/%s:%s", org, name, tag) + } + + stat, err := os.Stat(matches[0]) + if err != nil { + return nil, err + } + + if !stat.IsDir() { + return nil, fmt.Errorf("found extension is a file not a directory %s/%s:%s", org, name, tag) + } + + sig, err := extension.ReadSchema(path.Join(matches[0], "extension")) + if err != nil { + return nil, err + } + + return &Extension{ + Name: name, + Tag: tag, + Schema: sig, + Hash: getHashFromName(filepath.Base(matches[0])), + Organization: getOrgFromName(filepath.Base(matches[0])), + }, nil +} + +func (s *ExtensionStorage) Path(name string, tag string, org string, hash string) (string, error) { + if name == "" || !scalefunc.ValidString(name) { + return "", ErrInvalidName + } + + if tag == "" || !scalefunc.ValidString(tag) { + return "", ErrInvalidTag + } + + if org == "" || !scalefunc.ValidString(org) { + return "", ErrInvalidOrganization + } + + if hash != "" { + f := s.extensionName(name, tag, org, hash) + p := s.fullPath(f) + + stat, err := os.Stat(p) + if err != nil { + return "", err + } + + if !stat.IsDir() { + return "", fmt.Errorf("found extension is a file not a directory %s/%s:%s", org, name, tag) + } + + return p, nil + } + + f := s.extensionSearch(name, tag, org) + p := s.fullPath(f) + + matches, err := filepath.Glob(p) + if err != nil { + return "", err + } + + if len(matches) == 0 { + return "", nil + } + + if len(matches) > 1 { + return "", fmt.Errorf("multiple matches found for %s/%s:%s", org, name, tag) + } + + stat, err := os.Stat(matches[0]) + if err != nil { + return "", err + } + + if !stat.IsDir() { + return "", fmt.Errorf("found extension is a file not a directory %s/%s:%s", org, name, tag) + } + + return matches[0], nil +} + +// Put stores the Scale Extension with the given name, tag, organization +func (s *ExtensionStorage) Put(name string, tag string, org string, sig *extension.Schema) error { + hash, err := sig.Hash() + if err != nil { + return err + } + + hashString := hex.EncodeToString(hash) + + f := s.extensionName(name, tag, org, hashString) + directory := s.fullPath(f) + err = os.MkdirAll(directory, 0755) + if err != nil { + return err + } + + err = GenerateExtension(sig, name, tag, org, directory) + if err != nil { + return err + } + + return nil +} + +// Delete removes the Scale Extension with the given name, tag, org, and hash +func (s *ExtensionStorage) Delete(name string, tag string, org string, hash string) error { + return os.RemoveAll(s.fullPath(s.extensionName(name, tag, org, hash))) +} + +// List returns all the Scale Extensions stored in the storage +func (s *ExtensionStorage) List() ([]Extension, error) { + entries, err := os.ReadDir(s.Directory) + if err != nil { + return nil, fmt.Errorf("failed to read storage directory %s: %w", s.Directory, err) + } + var scaleExtensionEntries []Extension + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + sig, err := extension.ReadSchema(path.Join(s.fullPath(entry.Name()), "extension")) + if err != nil { + return nil, fmt.Errorf("failed to decode scale extension %s: %w", s.fullPath(entry.Name()), err) + } + scaleExtensionEntries = append(scaleExtensionEntries, Extension{ + Name: getNameFromName(entry.Name()), + Tag: getTagFromName(entry.Name()), + Schema: sig, + Hash: getHashFromName(entry.Name()), + Organization: getOrgFromName(entry.Name()), + }) + } + return scaleExtensionEntries, nil +} + +func (s *ExtensionStorage) fullPath(p string) string { + return path.Join(s.Directory, p) +} + +func (s *ExtensionStorage) extensionName(name string, tag string, org string, hash string) string { + return fmt.Sprintf("%s_%s_%s_%s_extension", org, name, tag, hash) +} + +func (s *ExtensionStorage) extensionSearch(name string, tag string, org string) string { + return fmt.Sprintf("%s_%s_%s_*_extension", org, name, tag) +} + +// GenerateExtension generates the extension files and writes them to +// the given path. +func GenerateExtension(ext *extension.Schema, name string, tag string, org string, directory string) error { + encoded, err := ext.Encode() + if err != nil { + return err + } + + err = os.WriteFile(path.Join(directory, "extension"), encoded, 0644) + if err != nil { + return err + } + + err = os.MkdirAll(path.Join(directory, "golang", "guest"), 0755) + if err != nil { + return err + } + + err = os.MkdirAll(path.Join(directory, "rust", "guest"), 0755) + if err != nil { + return err + } + + err = os.MkdirAll(path.Join(directory, "golang", "host"), 0755) + if err != nil { + return err + } + + guestPackage, err := generator.GenerateGuestLocal(&generator.Options{ + Extension: ext, + GolangPackageImportPath: "extension", + GolangPackageName: ext.Name, + GolangPackageVersion: "v0.1.0", + + RustPackageName: fmt.Sprintf("%s_%s_%s_guest", org, name, tag), + RustPackageVersion: "0.1.0", + }) + if err != nil { + return err + } + + for _, file := range guestPackage.GolangFiles { + err = os.WriteFile(path.Join(directory, "golang", "guest", file.Path()), file.Data(), 0644) + if err != nil { + return err + } + } + + for _, file := range guestPackage.RustFiles { + err = os.WriteFile(path.Join(directory, "rust", "guest", file.Path()), file.Data(), 0644) + if err != nil { + return err + } + } + + hostPackage, err := generator.GenerateHostLocal(&generator.Options{ + Extension: ext, + GolangPackageImportPath: "extension", + GolangPackageName: ext.Name, + GolangPackageVersion: "v0.1.0", + }) + if err != nil { + return err + } + + for _, file := range hostPackage.GolangFiles { + err = os.WriteFile(path.Join(directory, "golang", "host", file.Path()), file.Data(), 0644) + if err != nil { + return err + } + } + + return nil +}