|
| 1 | +/* |
| 2 | +Copyright IBM Corp. All Rights Reserved. |
| 3 | +
|
| 4 | +SPDX-License-Identifier: Apache-2.0 |
| 5 | +*/ |
| 6 | + |
| 7 | +package comm |
| 8 | + |
| 9 | +import ( |
| 10 | + "context" |
| 11 | + "crypto/tls" |
| 12 | + "crypto/x509" |
| 13 | + "time" |
| 14 | + |
| 15 | + "github.com/pkg/errors" |
| 16 | + "google.golang.org/grpc" |
| 17 | + "google.golang.org/grpc/credentials" |
| 18 | + "google.golang.org/grpc/keepalive" |
| 19 | +) |
| 20 | + |
| 21 | +type GRPCClient interface { |
| 22 | + // Certificate returns the tls.Certificate used to make TLS connections |
| 23 | + // when client certificates are required by the server |
| 24 | + Certificate() tls.Certificate |
| 25 | + // TLSEnabled is a flag indicating whether to use TLS for client |
| 26 | + // connections |
| 27 | + TLSEnabled() bool |
| 28 | + // MutualTLSRequired is a flag indicating whether the client |
| 29 | + // must send a certificate when making TLS connections |
| 30 | + MutualTLSRequired() bool |
| 31 | + // SetMaxRecvMsgSize sets the maximum message size the client can receive |
| 32 | + SetMaxRecvMsgSize(size int) |
| 33 | + // SetMaxSendMsgSize sets the maximum message size the client can send |
| 34 | + SetMaxSendMsgSize(size int) |
| 35 | + // SetServerRootCAs sets the list of authorities used to verify server |
| 36 | + // certificates based on a list of PEM-encoded X509 certificate authorities |
| 37 | + SetServerRootCAs(clientRoots [][]byte) error |
| 38 | + // NewConnection returns a grpc.ClientConn for the target address |
| 39 | + NewConnection(address string) (*grpc.ClientConn, error) |
| 40 | +} |
| 41 | + |
| 42 | +type grpcClient struct { |
| 43 | + // Set of PEM-encoded X509 certificate authorities used to populate |
| 44 | + // the tlsConfig.RootCAs indexed by subject |
| 45 | + serverRootCAs map[string]*x509.Certificate |
| 46 | + // TLS configuration used by the grpc.ClientConn |
| 47 | + tlsConfig *tls.Config |
| 48 | + // Flag indicating whether TLS is enabled |
| 49 | + tlsEnabled bool |
| 50 | + // Flag indicating whether a client certificate is required |
| 51 | + mutualTLSRequired bool |
| 52 | + // Options for setting up new connections |
| 53 | + dialOpts []grpc.DialOption |
| 54 | + // Duration for which to block while established a new connection |
| 55 | + timeout time.Duration |
| 56 | + // Maximum message size the client can receive |
| 57 | + maxRecvMsgSize int |
| 58 | + // Maximum message size the client can send |
| 59 | + maxSendMsgSize int |
| 60 | +} |
| 61 | + |
| 62 | +// NewGRPCClient creates a new implementation of GRPCClient given an address |
| 63 | +// and client configuration |
| 64 | +func NewGRPCClient(config ClientConfig) (GRPCClient, error) { |
| 65 | + client := &grpcClient{ |
| 66 | + tlsEnabled: false, |
| 67 | + mutualTLSRequired: false} |
| 68 | + |
| 69 | + // parse secure options |
| 70 | + err := client.parseSecureOptions(config.SecOpts) |
| 71 | + if err != nil { |
| 72 | + return client, err |
| 73 | + } |
| 74 | + |
| 75 | + // keepalive options |
| 76 | + var kap keepalive.ClientParameters |
| 77 | + if config.KaOpts != nil { |
| 78 | + kap = keepalive.ClientParameters{ |
| 79 | + Time: config.KaOpts.ClientInterval, |
| 80 | + Timeout: config.KaOpts.ClientTimeout} |
| 81 | + } else { |
| 82 | + // use defaults |
| 83 | + kap = keepalive.ClientParameters{ |
| 84 | + Time: keepaliveOptions.ClientInterval, |
| 85 | + Timeout: keepaliveOptions.ClientTimeout} |
| 86 | + } |
| 87 | + kap.PermitWithoutStream = true |
| 88 | + // set keepalive and blocking |
| 89 | + client.dialOpts = append(client.dialOpts, grpc.WithKeepaliveParams(kap), |
| 90 | + grpc.WithBlock()) |
| 91 | + client.timeout = config.Timeout |
| 92 | + // set send/recv message size to package defaults |
| 93 | + client.maxRecvMsgSize = maxRecvMsgSize |
| 94 | + client.maxSendMsgSize = maxSendMsgSize |
| 95 | + |
| 96 | + return client, nil |
| 97 | +} |
| 98 | + |
| 99 | +func (client *grpcClient) parseSecureOptions(opts *SecureOptions) error { |
| 100 | + |
| 101 | + if opts == nil || !opts.UseTLS { |
| 102 | + return nil |
| 103 | + } |
| 104 | + client.tlsEnabled = true |
| 105 | + client.tlsConfig = &tls.Config{ |
| 106 | + MinVersion: tls.VersionTLS12} // TLS 1.2 only |
| 107 | + if len(opts.ServerRootCAs) > 0 { |
| 108 | + client.tlsConfig.RootCAs = x509.NewCertPool() |
| 109 | + for _, certBytes := range opts.ServerRootCAs { |
| 110 | + err := AddPemToCertPool(certBytes, client.tlsConfig.RootCAs) |
| 111 | + if err != nil { |
| 112 | + commLogger.Debugf("error adding root certificate: %v", err) |
| 113 | + return errors.WithMessage(err, |
| 114 | + "error adding root certificate") |
| 115 | + } |
| 116 | + } |
| 117 | + } |
| 118 | + if opts.RequireClientCert { |
| 119 | + client.mutualTLSRequired = true |
| 120 | + // make sure we have both Key and Certificate |
| 121 | + if opts.Key != nil && |
| 122 | + opts.Certificate != nil { |
| 123 | + cert, err := tls.X509KeyPair(opts.Certificate, |
| 124 | + opts.Key) |
| 125 | + if err != nil { |
| 126 | + return errors.WithMessage(err, "failed to "+ |
| 127 | + "load client certificate") |
| 128 | + } |
| 129 | + client.tlsConfig.Certificates = append( |
| 130 | + client.tlsConfig.Certificates, cert) |
| 131 | + } else { |
| 132 | + return errors.New("both Key and Certificate " + |
| 133 | + "are required when using mutual TLS") |
| 134 | + } |
| 135 | + } |
| 136 | + return nil |
| 137 | +} |
| 138 | + |
| 139 | +// Certificate returns the tls.Certificate used to make TLS connections |
| 140 | +// when client certificates are required by the server |
| 141 | +func (client *grpcClient) Certificate() tls.Certificate { |
| 142 | + cert := tls.Certificate{} |
| 143 | + if client.tlsConfig != nil && len(client.tlsConfig.Certificates) > 0 { |
| 144 | + cert = client.tlsConfig.Certificates[0] |
| 145 | + } |
| 146 | + return cert |
| 147 | +} |
| 148 | + |
| 149 | +// TLSEnabled is a flag indicating whether to use TLS for client |
| 150 | +// connections |
| 151 | +func (client *grpcClient) TLSEnabled() bool { |
| 152 | + return client.tlsEnabled |
| 153 | +} |
| 154 | + |
| 155 | +// MutualTLSRequired is a flag indicating whether the client |
| 156 | +// must send a certificate when making TLS connections |
| 157 | +func (client *grpcClient) MutualTLSRequired() bool { |
| 158 | + return client.mutualTLSRequired |
| 159 | +} |
| 160 | + |
| 161 | +// SetMaxRecvMsgSize sets the maximum message size the client can receive |
| 162 | +func (client *grpcClient) SetMaxRecvMsgSize(size int) { |
| 163 | + client.maxRecvMsgSize = size |
| 164 | +} |
| 165 | + |
| 166 | +// SetMaxSendMsgSize sets the maximum message size the client can send |
| 167 | +func (client *grpcClient) SetMaxSendMsgSize(size int) { |
| 168 | + client.maxSendMsgSize = size |
| 169 | +} |
| 170 | + |
| 171 | +// SetServerRootCAs sets the list of authorities used to verify server |
| 172 | +// certificates based on a list of PEM-encoded X509 certificate authorities |
| 173 | +func (client *grpcClient) SetServerRootCAs(serverRoots [][]byte) error { |
| 174 | + |
| 175 | + // NOTE: if no serverRoots are specified, the current cert pool will be |
| 176 | + // replaced with an empty one |
| 177 | + certPool := x509.NewCertPool() |
| 178 | + for _, root := range serverRoots { |
| 179 | + err := AddPemToCertPool(root, certPool) |
| 180 | + if err != nil { |
| 181 | + return errors.WithMessage(err, "error adding root certificate") |
| 182 | + } |
| 183 | + } |
| 184 | + client.tlsConfig.RootCAs = certPool |
| 185 | + return nil |
| 186 | +} |
| 187 | + |
| 188 | +// NewConnection returns a grpc.ClientConn for the target address |
| 189 | +func (client *grpcClient) NewConnection(address string) ( |
| 190 | + *grpc.ClientConn, error) { |
| 191 | + |
| 192 | + var dialOpts []grpc.DialOption |
| 193 | + dialOpts = append(dialOpts, client.dialOpts...) |
| 194 | + |
| 195 | + // set transport credentials and max send/recv message sizes |
| 196 | + // immediately before creating a connection in order to allow |
| 197 | + // SetServerRootCAs / SetMaxRecvMsgSize / SetMaxSendMsgSize |
| 198 | + // to take effect on a per connection basis |
| 199 | + if client.tlsConfig != nil { |
| 200 | + dialOpts = append(dialOpts, |
| 201 | + grpc.WithTransportCredentials( |
| 202 | + credentials.NewTLS(client.tlsConfig))) |
| 203 | + } else { |
| 204 | + dialOpts = append(dialOpts, grpc.WithInsecure()) |
| 205 | + } |
| 206 | + dialOpts = append(dialOpts, grpc.WithDefaultCallOptions( |
| 207 | + grpc.MaxCallRecvMsgSize(client.maxRecvMsgSize), |
| 208 | + grpc.MaxCallSendMsgSize(client.maxSendMsgSize))) |
| 209 | + |
| 210 | + ctx, cancel := context.WithTimeout(context.Background(), client.timeout) |
| 211 | + defer cancel() |
| 212 | + conn, err := grpc.DialContext(ctx, address, dialOpts...) |
| 213 | + if err != nil { |
| 214 | + return nil, errors.WithMessage(errors.WithStack(err), |
| 215 | + "failed to create new connection") |
| 216 | + } |
| 217 | + return conn, nil |
| 218 | +} |
0 commit comments