From a9564f898cc3c5ff5f5c4f6ae2b523c99b94b7c3 Mon Sep 17 00:00:00 2001 From: Sam L Date: Sun, 19 May 2019 17:26:06 -0400 Subject: [PATCH] Update TypeExtension definition --- go.mod | 1 - language/ast/definitions.go | 38 ---- language/ast/type_extensions.go | 53 +++++ language/parser/parser.go | 279 +++++++++++++++++++++++++- language/parser/schema_parser_test.go | 135 +++++-------- 5 files changed, 384 insertions(+), 122 deletions(-) delete mode 100644 go.mod create mode 100644 language/ast/type_extensions.go diff --git a/go.mod b/go.mod deleted file mode 100644 index 399b200d..00000000 --- a/go.mod +++ /dev/null @@ -1 +0,0 @@ -module github.com/graphql-go/graphql diff --git a/language/ast/definitions.go b/language/ast/definitions.go index e16cf18d..426a6ea4 100644 --- a/language/ast/definitions.go +++ b/language/ast/definitions.go @@ -148,44 +148,6 @@ func (vd *VariableDefinition) GetLoc() *Location { return vd.Loc } -// TypeExtensionDefinition implements Node, Definition -type TypeExtensionDefinition struct { - Kind string - Loc *Location - Definition *ObjectDefinition -} - -func NewTypeExtensionDefinition(def *TypeExtensionDefinition) *TypeExtensionDefinition { - if def == nil { - def = &TypeExtensionDefinition{} - } - return &TypeExtensionDefinition{ - Kind: kinds.TypeExtensionDefinition, - Loc: def.Loc, - Definition: def.Definition, - } -} - -func (def *TypeExtensionDefinition) GetKind() string { - return def.Kind -} - -func (def *TypeExtensionDefinition) GetLoc() *Location { - return def.Loc -} - -func (def *TypeExtensionDefinition) GetVariableDefinitions() []*VariableDefinition { - return []*VariableDefinition{} -} - -func (def *TypeExtensionDefinition) GetSelectionSet() *SelectionSet { - return &SelectionSet{} -} - -func (def *TypeExtensionDefinition) GetOperation() string { - return "" -} - // DirectiveDefinition implements Node, Definition type DirectiveDefinition struct { Kind string diff --git a/language/ast/type_extensions.go b/language/ast/type_extensions.go new file mode 100644 index 00000000..f370527e --- /dev/null +++ b/language/ast/type_extensions.go @@ -0,0 +1,53 @@ +package ast + +import "github.com/graphql-go/graphql/language/kinds" + +type TypeExtension interface { + GetKind() string + GetLoc() *Location +} + +var _ TypeExtension = (*ScalarDefinition)(nil) +var _ TypeExtension = (*ObjectDefinition)(nil) +var _ TypeExtension = (*InterfaceDefinition)(nil) +var _ TypeExtension = (*UnionDefinition)(nil) +var _ TypeExtension = (*EnumDefinition)(nil) +var _ TypeExtension = (*InputObjectDefinition)(nil) + +// TypeExtensionDefinition implements Node, Definition +type TypeExtensionDefinition struct { + Kind string + Loc *Location + Definition TypeExtension +} + +func NewTypeExtensionDefinition(def *TypeExtensionDefinition) *TypeExtensionDefinition { + if def == nil { + def = &TypeExtensionDefinition{} + } + return &TypeExtensionDefinition{ + Kind: kinds.TypeExtensionDefinition, + Loc: def.Loc, + Definition: def.Definition, + } +} + +func (def *TypeExtensionDefinition) GetKind() string { + return def.Kind +} + +func (def *TypeExtensionDefinition) GetLoc() *Location { + return def.Loc +} + +func (def *TypeExtensionDefinition) GetVariableDefinitions() []*VariableDefinition { + return []*VariableDefinition{} +} + +func (def *TypeExtensionDefinition) GetSelectionSet() *SelectionSet { + return &SelectionSet{} +} + +func (def *TypeExtensionDefinition) GetOperation() string { + return "" +} diff --git a/language/parser/parser.go b/language/parser/parser.go index 4ee1577c..c5be77ec 100644 --- a/language/parser/parser.go +++ b/language/parser/parser.go @@ -924,6 +924,31 @@ func parseOperationTypeDefinition(parser *Parser) (interface{}, error) { }), nil } +/** + * ScalarTypeExtensionDefinition : extend scalar Name Directives + */ +func parseScalarTypeExtensionDefinition(parser *Parser) (ast.Node, error) { + start := parser.Token.Start + _, err := expectKeyWord(parser, lexer.SCALAR) + if err != nil { + return nil, err + } + name, err := parseName(parser) + if err != nil { + return nil, err + } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } + def := ast.NewScalarDefinition(&ast.ScalarDefinition{ + Name: name, + Directives: directives, + Loc: loc(parser, start), + }) + return def, nil +} + /** * ScalarTypeDefinition : Description? scalar Name Directives? */ @@ -954,6 +979,58 @@ func parseScalarTypeDefinition(parser *Parser) (ast.Node, error) { return def, nil } +/** + * ObjectTypeExtensionDefinition : + * extend type Name ImplementsInterfaces? Directives? { FieldDefinition+ } + * extend type Name ImplementsInterfaces? Directives + * extend type Name ImplementsInterfaces + */ +func parseObjectTypeExtensionDefinition(parser *Parser) (ast.Node, error) { + start := parser.Token.Start + _, err := expectKeyWord(parser, lexer.TYPE) + if err != nil { + return nil, err + } + name, err := parseName(parser) + if err != nil { + return nil, err + } + interfaces, err := parseImplementsInterfaces(parser) + if err != nil { + return nil, err + } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } + iFields, err := reverse(parser, + lexer.BRACE_L, parseFieldDefinition, lexer.BRACE_R, + false, + ) + if err != nil { + return nil, err + } + fields := []*ast.FieldDefinition{} + for _, iField := range iFields { + if iField != nil { + fields = append(fields, iField.(*ast.FieldDefinition)) + } + } + + // Must have at least one defined + if len(fields) == 0 && len(directives) == 0 && len(interfaces) == 0 { + return nil, unexpected(parser, parser.Token) + } + + return ast.NewObjectDefinition(&ast.ObjectDefinition{ + Name: name, + Loc: loc(parser, start), + Interfaces: interfaces, + Directives: directives, + Fields: fields, + }), nil +} + /** * ObjectTypeDefinition : * Description? @@ -1145,6 +1222,49 @@ func parseInputValueDef(parser *Parser) (interface{}, error) { }), nil } +/** + * InterfaceTypeExtensionDefinition : + * extend interface Name Directives? { FieldDefinition+ } + * extend interface Name Directives + */ +func parseInterfaceTypeExtensionDefinition(parser *Parser) (ast.Node, error) { + start := parser.Token.Start + _, err := expectKeyWord(parser, lexer.INTERFACE) + if err != nil { + return nil, err + } + name, err := parseName(parser) + if err != nil { + return nil, err + } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } + iFields, err := reverse(parser, + lexer.BRACE_L, parseFieldDefinition, lexer.BRACE_R, + false, + ) + if err != nil { + return nil, err + } + fields := []*ast.FieldDefinition{} + for _, iField := range iFields { + if iField != nil { + fields = append(fields, iField.(*ast.FieldDefinition)) + } + } + if len(fields) == 0 && len(directives) == 0 { + return nil, unexpected(parser, parser.Token) + } + return ast.NewInterfaceDefinition(&ast.InterfaceDefinition{ + Name: name, + Directives: directives, + Loc: loc(parser, start), + Fields: fields, + }), nil +} + /** * InterfaceTypeDefinition : * Description? @@ -1190,6 +1310,49 @@ func parseInterfaceTypeDefinition(parser *Parser) (ast.Node, error) { }), nil } +/** + * UnionTypeExtension : + * extend union Name Directives? = UnionMembers + * extend union Name Directives + */ +func parseUnionTypeExtensionDefinition(parser *Parser) (ast.Node, error) { + start := parser.Token.Start + _, err := expectKeyWord(parser, lexer.UNION) + if err != nil { + return nil, err + } + name, err := parseName(parser) + if err != nil { + return nil, err + } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } + _, err = expect(parser, lexer.EQUALS) + if err != nil { + return ast.NewUnionDefinition(&ast.UnionDefinition{ + Name: name, + Directives: directives, + Loc: loc(parser, start), + Types: []*ast.Named{}, + }), nil + } + types, err := parseUnionMembers(parser) + if err != nil { + return nil, err + } + if len(types) == 0 && len(directives) == 0 { + return nil, unexpected(parser, parser.Token) + } + return ast.NewUnionDefinition(&ast.UnionDefinition{ + Name: name, + Directives: directives, + Loc: loc(parser, start), + Types: types, + }), nil +} + /** * UnionTypeDefinition : Description? union Name Directives? = UnionMembers */ @@ -1250,6 +1413,49 @@ func parseUnionMembers(parser *Parser) ([]*ast.Named, error) { return members, nil } +/** + * EnumTypeExtensionDefinition : + * extend enum Name Directives? { EnumValueDefinition+ } + * extend enum Name Directives + */ +func parseEnumTypeExtensionDefinition(parser *Parser) (ast.Node, error) { + start := parser.Token.Start + _, err := expectKeyWord(parser, lexer.ENUM) + if err != nil { + return nil, err + } + name, err := parseName(parser) + if err != nil { + return nil, err + } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } + iEnumValueDefs, err := reverse(parser, + lexer.BRACE_L, parseEnumValueDefinition, lexer.BRACE_R, + false, + ) + if err != nil { + return nil, err + } + values := []*ast.EnumValueDefinition{} + for _, iEnumValueDef := range iEnumValueDefs { + if iEnumValueDef != nil { + values = append(values, iEnumValueDef.(*ast.EnumValueDefinition)) + } + } + if len(values) == 0 && len(directives) == 0 { + return nil, unexpected(parser, parser.Token) + } + return ast.NewEnumDefinition(&ast.EnumDefinition{ + Name: name, + Directives: directives, + Loc: loc(parser, start), + Values: values, + }), nil +} + /** * EnumTypeDefinition : Description? enum Name Directives? { EnumValueDefinition+ } */ @@ -1320,6 +1526,49 @@ func parseEnumValueDefinition(parser *Parser) (interface{}, error) { }), nil } +/** + * InputObjectTypeExtensionDefinition : + * extend input Name Directives? { InputValueDefinition+ } + * extend input Name Directives + */ +func parseInputObjectTypeExtensionDefinition(parser *Parser) (ast.Node, error) { + start := parser.Token.Start + _, err := expectKeyWord(parser, lexer.INPUT) + if err != nil { + return nil, err + } + name, err := parseName(parser) + if err != nil { + return nil, err + } + directives, err := parseDirectives(parser) + if err != nil { + return nil, err + } + iInputValueDefinitions, err := reverse(parser, + lexer.BRACE_L, parseInputValueDef, lexer.BRACE_R, + false, + ) + if err != nil { + return nil, err + } + fields := []*ast.InputValueDefinition{} + for _, iInputValueDefinition := range iInputValueDefinitions { + if iInputValueDefinition != nil { + fields = append(fields, iInputValueDefinition.(*ast.InputValueDefinition)) + } + } + if len(fields) == 0 && len(directives) == 0 { + return nil, unexpected(parser, parser.Token) + } + return ast.NewInputObjectDefinition(&ast.InputObjectDefinition{ + Name: name, + Directives: directives, + Loc: loc(parser, start), + Fields: fields, + }), nil +} + /** * InputObjectTypeDefinition : * - Description? input Name Directives? { InputValueDefinition+ } @@ -1365,7 +1614,13 @@ func parseInputObjectTypeDefinition(parser *Parser) (ast.Node, error) { } /** - * TypeExtensionDefinition : extend ObjectTypeDefinition + * TypeExtensionDefinition : + * ScalarTypeExtension + * ObjectTypeExtension + * InterfaceTypeExtension + * UnionTypeExtension + * EnumTypeExtension + * InputObjectTypeExtension */ func parseTypeExtensionDefinition(parser *Parser) (ast.Node, error) { start := parser.Token.Start @@ -1374,13 +1629,31 @@ func parseTypeExtensionDefinition(parser *Parser) (ast.Node, error) { return nil, err } - definition, err := parseObjectTypeDefinition(parser) + token := parser.Token + var definition ast.TypeExtension + switch token.Value { + case lexer.SCALAR: + definition, err = parseScalarTypeExtensionDefinition(parser) + case lexer.TYPE: + definition, err = parseObjectTypeExtensionDefinition(parser) + case lexer.INTERFACE: + definition, err = parseInterfaceTypeExtensionDefinition(parser) + case lexer.UNION: + definition, err = parseUnionTypeExtensionDefinition(parser) + case lexer.ENUM: + definition, err = parseEnumTypeExtensionDefinition(parser) + case lexer.INPUT: + definition, err = parseInputObjectTypeExtensionDefinition(parser) + default: + return nil, unexpected(parser, token) + } if err != nil { return nil, err } + return ast.NewTypeExtensionDefinition(&ast.TypeExtensionDefinition{ Loc: loc(parser, start), - Definition: definition.(*ast.ObjectDefinition), + Definition: definition, }), nil } diff --git a/language/parser/schema_parser_test.go b/language/parser/schema_parser_test.go index b7d31f3d..3bb6e950 100644 --- a/language/parser/schema_parser_test.go +++ b/language/parser/schema_parser_test.go @@ -4,6 +4,7 @@ import ( "reflect" "testing" + "github.com/go-test/deep" "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/location" @@ -74,7 +75,7 @@ type Hello { } } -func TestSchemaParser_SimpleExtension(t *testing.T) { +func TestSchemaParser_SimpleTypeExtension(t *testing.T) { body := ` extend type Hello { @@ -116,41 +117,39 @@ extend type Hello { }), }, }) - if !reflect.DeepEqual(astDoc, expected) { - t.Fatalf("unexpected document, expected: %v, got: %v", expected, astDoc) + if diff := deep.Equal(astDoc, expected); diff != nil { + t.Fatal(diff) } } func TestSchemaParser_SimpleInputExtension(t *testing.T) { body := ` -extend input hello { - world: string +extend input Hello { + world: String } ` astDoc := parse(t, body) expected := ast.NewDocument(&ast.Document{ - Loc: testLoc(1, 38), + Loc: testLoc(1, 39), Definitions: []ast.Node{ ast.NewTypeExtensionDefinition(&ast.TypeExtensionDefinition{ Loc: testLoc(1, 38), - Definition: ast.NewObjectDefinition(&ast.ObjectDefinition{ + Definition: ast.NewInputObjectDefinition(&ast.InputObjectDefinition{ Loc: testLoc(8, 38), Name: ast.NewName(&ast.Name{ Value: "Hello", - Loc: testLoc(13, 18), + Loc: testLoc(14, 19), }), Directives: []*ast.Directive{}, - Interfaces: []*ast.Named{}, - Fields: []*ast.FieldDefinition{ - ast.NewFieldDefinition(&ast.FieldDefinition{ + Fields: []*ast.InputValueDefinition{ + ast.NewInputValueDefinition(&ast.InputValueDefinition{ Loc: testLoc(23, 36), Name: ast.NewName(&ast.Name{ Value: "world", Loc: testLoc(23, 28), }), Directives: []*ast.Directive{}, - Arguments: []*ast.InputValueDefinition{}, Type: ast.NewNamed(&ast.Named{ Loc: testLoc(30, 36), Name: ast.NewName(&ast.Name{ @@ -164,46 +163,45 @@ extend input hello { }), }, }) - if !reflect.DeepEqual(astDoc, expected) { - t.Fatalf("unexpected document, expected: %v, got: %v", expected, astDoc) + if diff := deep.Equal(astDoc, expected); diff != nil { + t.Fatal(diff) } } func TestSchemaParser_SimpleInterfaceExtension(t *testing.T) { body := ` -extend interface hello { - world: string +extend interface Hello { + world: String } ` astDoc := parse(t, body) expected := ast.NewDocument(&ast.Document{ - Loc: testLoc(1, 38), + Loc: testLoc(1, 43), Definitions: []ast.Node{ ast.NewTypeExtensionDefinition(&ast.TypeExtensionDefinition{ - Loc: testLoc(1, 38), - Definition: ast.NewObjectDefinition(&ast.ObjectDefinition{ - Loc: testLoc(8, 38), + Loc: testLoc(1, 42), + Definition: ast.NewInterfaceDefinition(&ast.InterfaceDefinition{ + Loc: testLoc(8, 42), Name: ast.NewName(&ast.Name{ Value: "Hello", - Loc: testLoc(13, 18), + Loc: testLoc(18, 23), }), Directives: []*ast.Directive{}, - Interfaces: []*ast.Named{}, Fields: []*ast.FieldDefinition{ ast.NewFieldDefinition(&ast.FieldDefinition{ - Loc: testLoc(23, 36), + Loc: testLoc(27, 40), Name: ast.NewName(&ast.Name{ Value: "world", - Loc: testLoc(23, 28), + Loc: testLoc(27, 32), }), Directives: []*ast.Directive{}, Arguments: []*ast.InputValueDefinition{}, Type: ast.NewNamed(&ast.Named{ - Loc: testLoc(30, 36), + Loc: testLoc(34, 40), Name: ast.NewName(&ast.Name{ Value: "String", - Loc: testLoc(30, 36), + Loc: testLoc(34, 40), }), }), }), @@ -212,100 +210,77 @@ extend interface hello { }), }, }) - if !reflect.DeepEqual(astDoc, expected) { - t.Fatalf("unexpected document, expected: %v, got: %v", expected, astDoc) + if diff := deep.Equal(astDoc, expected); diff != nil { + t.Fatal(diff) } } func TestSchemaParser_SimpleScalarExtension(t *testing.T) { - body := ` -directive @example on SCALAR -extend scalar string @example` + body := `extend scalar string @example` astDoc := parse(t, body) expected := ast.NewDocument(&ast.Document{ - Loc: testLoc(1, 38), + Loc: testLoc(0, 29), Definitions: []ast.Node{ ast.NewTypeExtensionDefinition(&ast.TypeExtensionDefinition{ - Loc: testLoc(1, 38), - Definition: ast.NewObjectDefinition(&ast.ObjectDefinition{ - Loc: testLoc(8, 38), + Loc: testLoc(0, 29), + Definition: ast.NewScalarDefinition(&ast.ScalarDefinition{ + Loc: testLoc(7, 29), Name: ast.NewName(&ast.Name{ - Value: "Hello", - Loc: testLoc(13, 18), + Value: "string", + Loc: testLoc(14, 20), }), - Directives: []*ast.Directive{}, - Interfaces: []*ast.Named{}, - Fields: []*ast.FieldDefinition{ - ast.NewFieldDefinition(&ast.FieldDefinition{ - Loc: testLoc(23, 36), + Directives: []*ast.Directive{ + ast.NewDirective(&ast.Directive{ + Loc: testLoc(21, 29), Name: ast.NewName(&ast.Name{ - Value: "world", - Loc: testLoc(23, 28), - }), - Directives: []*ast.Directive{}, - Arguments: []*ast.InputValueDefinition{}, - Type: ast.NewNamed(&ast.Named{ - Loc: testLoc(30, 36), - Name: ast.NewName(&ast.Name{ - Value: "String", - Loc: testLoc(30, 36), - }), + Value: "example", + Loc: testLoc(22, 29), }), + Arguments: []*ast.Argument{}, }), }, }), }), }, }) - if !reflect.DeepEqual(astDoc, expected) { - t.Fatalf("unexpected document, expected: %v, got: %v", expected, astDoc) + if diff := deep.Equal(astDoc, expected); diff != nil { + t.Fatal(diff) } } func TestSchemaParser_SimpleUnionExtension(t *testing.T) { - body := ` -directive @example on UNION -extend union @example` + body := `extend union Hello @example` astDoc := parse(t, body) expected := ast.NewDocument(&ast.Document{ - Loc: testLoc(1, 38), + Loc: testLoc(0, 27), Definitions: []ast.Node{ ast.NewTypeExtensionDefinition(&ast.TypeExtensionDefinition{ - Loc: testLoc(1, 38), - Definition: ast.NewObjectDefinition(&ast.ObjectDefinition{ - Loc: testLoc(8, 38), + Loc: testLoc(0, 27), + Definition: ast.NewUnionDefinition(&ast.UnionDefinition{ + Loc: testLoc(7, 27), Name: ast.NewName(&ast.Name{ Value: "Hello", Loc: testLoc(13, 18), }), - Directives: []*ast.Directive{}, - Interfaces: []*ast.Named{}, - Fields: []*ast.FieldDefinition{ - ast.NewFieldDefinition(&ast.FieldDefinition{ - Loc: testLoc(23, 36), + Directives: []*ast.Directive{ + ast.NewDirective(&ast.Directive{ + Loc: testLoc(19, 27), Name: ast.NewName(&ast.Name{ - Value: "world", - Loc: testLoc(23, 28), - }), - Directives: []*ast.Directive{}, - Arguments: []*ast.InputValueDefinition{}, - Type: ast.NewNamed(&ast.Named{ - Loc: testLoc(30, 36), - Name: ast.NewName(&ast.Name{ - Value: "String", - Loc: testLoc(30, 36), - }), + Loc: testLoc(20, 27), + Value: "example", }), + Arguments: []*ast.Argument{}, }), }, + Types: []*ast.Named{}, }), }), }, }) - if !reflect.DeepEqual(astDoc, expected) { - t.Fatalf("unexpected document, expected: %v, got: %v", expected, astDoc) + if diff := deep.Equal(astDoc, expected); diff != nil { + t.Fatal(diff) } }