diff --git a/build.go b/build.go index 27c71b1b..e831a210 100644 --- a/build.go +++ b/build.go @@ -16,7 +16,7 @@ package openssl -// #cgo pkg-config: libssl +// #cgo pkg-config: libssl libcrypto // #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN // #cgo darwin CFLAGS: -Wno-deprecated-declarations import "C" diff --git a/conn.go b/conn.go index 9837ce3a..ee69974c 100644 --- a/conn.go +++ b/conn.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build cgo // +build cgo package openssl @@ -31,6 +32,9 @@ package openssl // const char * SSL_get_cipher_name_not_a_macro(const SSL *ssl) { // return SSL_get_cipher_name(ssl); // } +// int SSL_version_not_a_macro(const SSL *ssl) { +// return SSL_version(ssl); +// } import "C" import ( @@ -476,6 +480,43 @@ func (c *Conn) Read(b []byte) (n int, err error) { return 0, err } +func (c *Conn) peek(b []byte) (int, func() error) { + if len(b) == 0 { + return 0, nil + } + c.mtx.Lock() + defer c.mtx.Unlock() + if c.is_shutdown { + return 0, func() error { return io.EOF } + } + runtime.LockOSThread() + defer runtime.UnlockOSThread() + rv, errno := C.SSL_peek(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b))) + if rv > 0 { + return int(rv), nil + } + return 0, c.getErrorHandler(rv, errno) +} + +func (c *Conn) Peek(b []byte) (n int, err error) { + if len(b) == 0 { + return 0, nil + } + err = tryAgain + for err == tryAgain { + n, errcb := c.peek(b) + err = c.handleError(errcb) + if err == nil { + go c.flushOutputBuffer() + return n, nil + } + if err == io.ErrUnexpectedEOF { + err = io.EOF + } + } + return 0, err +} + func (c *Conn) write(b []byte) (int, func() error) { if len(b) == 0 { return 0, nil @@ -548,6 +589,11 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } +func (c *Conn) SetCtx(ctx *Ctx) { + c.ctx = ctx + C.SSL_set_SSL_CTX(c.ssl, ctx.ctx) +} + func (c *Conn) UnderlyingConn() net.Conn { return c.conn } @@ -566,3 +612,11 @@ func (c *Conn) SetTlsExtHostName(name string) error { func (c *Conn) VerifyResult() VerifyResult { return VerifyResult(C.SSL_get_verify_result(c.ssl)) } + +func (c *Conn) GetServerName() string { + return C.GoString(C.SSL_get_servername(c.ssl, C.TLSEXT_NAMETYPE_host_name)) +} + +func (c *Conn) Version() int { + return int(C.SSL_version_not_a_macro(c.ssl)) +} diff --git a/ctx.go b/ctx.go index 7db505ec..1c0c1793 100644 --- a/ctx.go +++ b/ctx.go @@ -70,6 +70,40 @@ static long SSL_CTX_set_tmp_ecdh_not_a_macro(SSL_CTX* ctx, EC_KEY *key) { return SSL_CTX_set_tmp_ecdh(ctx, key); } +static long SSL_CTX_set_tlsext_servername_callback_not_a_macro(SSL_CTX* ctx, void (*fp)()) { + return SSL_CTX_set_tlsext_servername_callback(ctx, fp); +} + +typedef struct TlsServernameData { + void *go_ctx; + SSL_CTX *ctx; + void *arg; +} TlsServernameData; + +static TlsServernameData* new_TlsServernameData() { + return calloc(1, sizeof(TlsServernameData)); +} + +//UNUSED: openssl doesn't have a way to unset SNI callback or arg. So we just leak whatever +//the function above allocates +//static void del_TlsServernameData(TlsServernameData *tsd) { +// free(tds); +//} + +extern int callServerNameCb(SSL* ssl, int ad, void* arg); + +static int call_go_servername(SSL* ssl, int ad, void* arg) { + return callServerNameCb(ssl, ad, arg); +} + +static int servername_gateway(TlsServernameData* cw) { + SSL_CTX* ctx = cw->ctx; + //TODO: figure out what to do with return codes. The first isn't 0 + SSL_CTX_set_tlsext_servername_callback(ctx, call_go_servername); + SSL_CTX_set_tlsext_servername_arg(ctx, cw); + return 0; +} + #ifndef SSL_MODE_RELEASE_BUFFERS #define SSL_MODE_RELEASE_BUFFERS 0 #endif @@ -117,11 +151,13 @@ var ( ) type Ctx struct { - ctx *C.SSL_CTX - cert *Certificate - chain []*Certificate - key PrivateKey - verify_cb VerifyCallback + ctx *C.SSL_CTX + cert *Certificate + chain []*Certificate + key PrivateKey + verify_cb VerifyCallback + servername_cb ServerNameCallback + ted *C.TlsServernameData } //export get_ssl_ctx_idx @@ -605,3 +641,35 @@ func (c *Ctx) SessSetCacheSize(t int) int { func (c *Ctx) SessGetCacheSize() int { return int(C.SSL_CTX_sess_get_cache_size_not_a_macro(c.ctx)) } + +// Set SSL_CTX_set_tlsext_servername_callback +// https://www.openssl.org/docs/manmaster/ssl/??? +type ServerNameCallback func(ssl Conn, ad int, arg unsafe.Pointer) int + +//export callServerNameCb +func callServerNameCb(ssl *C.SSL, ad C.int, arg unsafe.Pointer) C.int { + var ted *C.TlsServernameData = (*C.TlsServernameData)(arg) + goCtx := (*Ctx)(ted.go_ctx) + + //setup a dummy Conn so we can associate a SSL_CTX from user callback + conn := Conn{ + ssl: ssl, + ctx: goCtx, + } + ret := goCtx.servername_cb(conn, int(ad), ted.arg) + return C.int(ret) +} + +func (c *Ctx) SetTlsExtServerNameCallback(cb func(ssl Conn, ad int, arg unsafe.Pointer) int, + arg unsafe.Pointer) int { + c.servername_cb = cb + ted := C.new_TlsServernameData() + if ted == nil { + return 1 + } + ted.go_ctx = unsafe.Pointer(c) + ted.ctx = c.ctx + ted.arg = arg + c.ted = ted + return int(C.servername_gateway(c.ted)) +} diff --git a/tls_ext_test.go b/tls_ext_test.go new file mode 100644 index 00000000..e6015da1 --- /dev/null +++ b/tls_ext_test.go @@ -0,0 +1,120 @@ +// Copyright (C) 2014 Space Monkey, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package openssl + +import ( + "bytes" + "io" + "sync" + "testing" + "unsafe" +) + +var gFoundServerName bool = false +var gServerName string +var gCallbackData string = "some callback data" + +func passThroughServername() func(ssl Conn, ad int, arg unsafe.Pointer) int { + x := func(ssl Conn, ad int, arg unsafe.Pointer) int { + cbData := (*string)(arg) + if *cbData != gCallbackData { //we should getthe callback data we set on the CTX + return 1 + } + name := ssl.GetServerName() + if name == gServerName { + gFoundServerName = true + //here we'd normally do soemthing like get a CTX for the specific server name and + //set it on the conn. + } else { + gFoundServerName = false + } + return 0 + } + return x +} + +func TestTLSExtSNI(t *testing.T) { + //setup SNI On the CTX + server_conn, client_conn := NetPipe(t) + defer server_conn.Close() + defer client_conn.Close() + + server, client := OpenSSLConstructor(t, server_conn, client_conn) + cconn := client.(*Conn) + sconn := server.(*Conn) + ctx := (*sconn).ctx + //setup SNI On the CTX + rc := ctx.SetTlsExtServerNameCallback(passThroughServername(), unsafe.Pointer(&gCallbackData)) + if rc != 0 { + t.Fatal("Expected 0 from ctx.SetTlsExtServerNameCallback, but got %d", rc) + } + data := "first test string\n" + host := "test-host" + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + gServerName = host + err := cconn.SetTlsExtHostName(host) + if err != nil { + t.Fatal(err) + } + + err = client.Handshake() + if err != nil { + t.Fatal(err) + } + + _, err = io.Copy(client, bytes.NewReader([]byte(data))) + if err != nil { + t.Fatal(err) + } + + err = client.Close() + if err != nil { + t.Fatal(err) + } + }() + go func() { + defer wg.Done() + + err := server.Handshake() + if err != nil { + t.Fatal(err) + } + + buf := bytes.NewBuffer(make([]byte, 0, len(data))) + _, err = io.CopyN(buf, server, int64(len(data))) + if err != nil { + t.Fatal(err) + } + if string(buf.Bytes()) != data { + t.Fatal("mismatched data") + } + + err = server.Close() + if err != nil { + t.Fatal(err) + } + }() + wg.Wait() + if gFoundServerName == false { + t.Fatal("Expected gFoundServerName to be set to true") + } + if gServerName != host { + t.Fatal("Expected gServerName to be '%s', but it was '%s'", host, gServerName) + } +}