diff --git a/app.go b/app.go index 5f0007a35f..d7bb6abd4c 100644 --- a/app.go +++ b/app.go @@ -121,6 +121,8 @@ type App struct { mountFields *mountFields // Indicates if the value was explicitly configured configured Config + // customConstraints is a list of external constraints + customConstraints []CustomConstraint } // Config is a struct holding the server settings. @@ -588,6 +590,11 @@ func (app *App) NewCtxFunc(function func(app *App) CustomCtx) { app.newCtxFunc = function } +// RegisterCustomConstraint allows to register custom constraint. +func (app *App) RegisterCustomConstraint(constraint CustomConstraint) { + app.customConstraints = append(app.customConstraints, constraint) +} + // You can register custom binders to use as Bind().Custom("name"). // They should be compatible with CustomBinder interface. func (app *App) RegisterCustomBinder(binder CustomBinder) { diff --git a/app_test.go b/app_test.go index e61d23a95d..18016f1a8f 100644 --- a/app_test.go +++ b/app_test.go @@ -2,7 +2,7 @@ // 🤖 Github Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io -//nolint:bodyclose // Much easier to just ignore memory leaks in tests +//nolint:bodyclose, goconst // Much easier to just ignore memory leaks in tests package fiber import ( @@ -178,6 +178,61 @@ func Test_App_Errors(t *testing.T) { } } +type customConstraint struct{} + +func (*customConstraint) Name() string { + return "test" +} + +func (*customConstraint) Execute(param string, args ...string) bool { + if param == "test" && len(args) == 1 && args[0] == "test" { + return true + } + + if len(args) == 0 && param == "c" { + return true + } + + return false +} + +func Test_App_CustomConstraint(t *testing.T) { + app := New() + app.RegisterCustomConstraint(&customConstraint{}) + + app.Get("/test/:param", func(c Ctx) error { + return c.SendString("test") + }) + + app.Get("/test2/:param", func(c Ctx) error { + return c.SendString("test") + }) + + app.Get("/test3/:param", func(c Ctx) error { + return c.SendString("test") + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/test/test", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, 200, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/test/test2", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, 404, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/test2/c", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, 200, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/test2/cc", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, 404, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/test3/cc", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, 404, resp.StatusCode, "Status code") +} + func Test_App_ErrorHandler_Custom(t *testing.T) { t.Parallel() app := New(Config{ diff --git a/docs/api/app.md b/docs/api/app.md index 29022a3104..2e58a426d1 100644 --- a/docs/api/app.md +++ b/docs/api/app.md @@ -617,6 +617,16 @@ ln = tls.NewListener(ln, &tls.Config{Certificates: []tls.Certificate{cer}}) app.Listener(ln) ``` +## RegisterCustomConstraint + +RegisterCustomConstraint allows to register custom constraint. + +```go title="Signature" +func (app *App) RegisterCustomConstraint(constraint CustomConstraint) +``` + +See [Custom Constraint](../guide/routing.md#custom-constraint) section for more information. + ## Test Testing your application is done with the **Test** method. Use this method for creating `_test.go` files or when you need to debug your routing logic. The default timeout is `1s` if you want to disable a timeout altogether, pass `-1` as a second argument. diff --git a/docs/guide/routing.md b/docs/guide/routing.md index 31d187aeb5..70d746a317 100644 --- a/docs/guide/routing.md +++ b/docs/guide/routing.md @@ -240,6 +240,59 @@ app.Get("/:test?", func(c fiber.Ctx) error { // Cannot GET /7.0 ``` +**Custom Constraint Example** + +Custom constraints can be added to Fiber using the `app.RegisterCustomConstraint` method. Your constraints have to be compatible with the `CustomConstraint` interface. + +It is a good idea to add external constraints to your project once you want to add more specific rules to your routes. +For example, you can add a constraint to check if a parameter is a valid ULID. + +```go +// CustomConstraint is an interface for custom constraints +type CustomConstraint interface { + // Name returns the name of the constraint. + // This name is used in the constraint matching. + Name() string + + // Execute executes the constraint. + // It returns true if the constraint is matched and right. + // param is the parameter value to check. + // args are the constraint arguments. + Execute(param string, args ...string) bool +} +``` + +You can check the example below: + +```go +type UlidConstraint struct { + fiber.CustomConstraint +} + +func (*UlidConstraint) Name() string { + return "ulid" +} + +func (*UlidConstraint) Execute(param string, args ...string) bool { + _, err := ulid.Parse(param) + return err == nil +} + +func main() { + app := fiber.New() + app.RegisterCustomConstraint(&UlidConstraint{}) + + app.Get("/login/:id", func(c fiber.Ctx) error { + return c.SendString("...") + }) + + app.Listen(":3000") + + // /login/01HK7H9ZE5BFMK348CPYP14S0Z -> 200 + // /login/12345 -> 404 +} +``` + ## Middleware Functions that are designed to make changes to the request or response are called **middleware functions**. The [Next](../api/ctx.md#next) is a **Fiber** router function, when called, executes the **next** function that **matches** the current route. diff --git a/go.sum b/go.sum index 70a979163d..47a38d262f 100644 --- a/go.sum +++ b/go.sum @@ -64,4 +64,4 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= \ No newline at end of file diff --git a/path.go b/path.go index bfcfe68e6f..a51b31562e 100644 --- a/path.go +++ b/path.go @@ -65,9 +65,24 @@ const ( type TypeConstraint int16 type Constraint struct { - ID TypeConstraint - RegexCompiler *regexp.Regexp - Data []string + ID TypeConstraint + RegexCompiler *regexp.Regexp + Data []string + Name string + customConstraints []CustomConstraint +} + +// CustomConstraint is an interface for custom constraints +type CustomConstraint interface { + // Name returns the name of the constraint. + // This name is used in the constraint matching. + Name() string + + // Execute executes the constraint. + // It returns true if the constraint is matched and right. + // param is the parameter value to check. + // args are the constraint arguments. + Execute(param string, args ...string) bool } const ( @@ -175,15 +190,14 @@ func RoutePatternMatch(path, pattern string, cfg ...Config) bool { // parseRoute analyzes the route and divides it into segments for constant areas and parameters, // this information is needed later when assigning the requests to the declared routes -func parseRoute(pattern string) routeParser { +func parseRoute(pattern string, customConstraints ...CustomConstraint) routeParser { parser := routeParser{} - part := "" for len(pattern) > 0 { nextParamPosition := findNextParamPosition(pattern) // handle the parameter part if nextParamPosition == 0 { - processedPart, seg := parser.analyseParameterPart(pattern) + processedPart, seg := parser.analyseParameterPart(pattern, customConstraints...) parser.params, parser.segs, part = append(parser.params, seg.ParamName), append(parser.segs, seg), processedPart } else { processedPart, seg := parser.analyseConstantPart(pattern, nextParamPosition) @@ -284,7 +298,7 @@ func (*routeParser) analyseConstantPart(pattern string, nextParamPosition int) ( } // analyseParameterPart find the parameter end and create the route segment -func (routeParser *routeParser) analyseParameterPart(pattern string) (string, *routeSegment) { +func (routeParser *routeParser) analyseParameterPart(pattern string, customConstraints ...CustomConstraint) (string, *routeSegment) { isWildCard := pattern[0] == wildcardParam isPlusParam := pattern[0] == plusParam @@ -332,7 +346,9 @@ func (routeParser *routeParser) analyseParameterPart(pattern string) (string, *r // Assign constraint if start != -1 && end != -1 { constraint := &Constraint{ - ID: getParamConstraintType(c[:start]), + ID: getParamConstraintType(c[:start]), + Name: c[:start], + customConstraints: customConstraints, } // remove escapes from data @@ -355,8 +371,10 @@ func (routeParser *routeParser) analyseParameterPart(pattern string) (string, *r constraints = append(constraints, constraint) } else { constraints = append(constraints, &Constraint{ - ID: getParamConstraintType(c), - Data: []string{}, + ID: getParamConstraintType(c), + Data: []string{}, + Name: c, + customConstraints: customConstraints, }) } } @@ -666,7 +684,11 @@ func (c *Constraint) CheckConstraint(param string) bool { // check constraints switch c.ID { case noConstraint: - // Nothing to check + for _, cc := range c.customConstraints { + if cc.Name() == c.Name { + return cc.Execute(param, c.Data...) + } + } case intConstraint: _, err = strconv.Atoi(param) case boolConstraint: diff --git a/router.go b/router.go index f25e98c565..b210b5a56d 100644 --- a/router.go +++ b/router.go @@ -260,7 +260,7 @@ func (app *App) addPrefixToRoute(prefix string, route *Route) *Route { route.Path = prefixedPath route.path = RemoveEscapeChar(prettyPath) - route.routeParser = parseRoute(prettyPath) + route.routeParser = parseRoute(prettyPath, app.customConstraints...) route.root = false route.star = false @@ -335,8 +335,8 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler // Is path a root slash? isRoot := pathPretty == "/" // Parse path parameters - parsedRaw := parseRoute(pathRaw) - parsedPretty := parseRoute(pathPretty) + parsedRaw := parseRoute(pathRaw, app.customConstraints...) + parsedPretty := parseRoute(pathPretty, app.customConstraints...) // Create route metadata without pointer route := Route{