xref: /openbsd-src/regress/lib/libtls/gotls/tls_test.go (revision f2da64fbbbf1b03f09f390ab01267c93dfd77c4c)
1package tls
2
3import (
4	"encoding/pem"
5	"fmt"
6	"io/ioutil"
7	"net/http"
8	"net/http/httptest"
9	"net/url"
10	"os"
11	"strings"
12	"testing"
13	"time"
14)
15
// httpContent is the body written by the test server's handler; the tests
// look for it in the data read back over the libtls connection.
const httpContent = "Hello, TLS!"

// certHash is the value PeerCertHash is expected to report for the
// certificate presented by the httptest TLS server.
// NOTE(review): presumably tied to the certificate built into the Go
// toolchain's httptest package — confirm and update if that certificate changes.
const certHash = "SHA256:448f628a8a65aa18560e53a80c53acb38c51b427df0334082349141147dc9bf6"
21
// certNotBefore is the validity start PeerCertNotBefore is expected to
// report for the test server's certificate: the Unix epoch.
var certNotBefore = time.Unix(0, 0)

// certNotAfter is the validity end PeerCertNotAfter is expected to report:
// one million hours after certNotBefore.
var certNotAfter = certNotBefore.Add(time.Hour * 1000000)
26
// createCAFile writes a PEM encoded version of the certificate out to a
// temporary file, for use by libtls.
//
// cert is the DER bytes of the certificate; it is wrapped in a single
// "CERTIFICATE" PEM block. On success the path of the new file is returned
// and the caller is responsible for removing it. On failure the temporary
// file (if one was created) is removed before returning the error.
func createCAFile(cert []byte) (string, error) {
	f, err := ioutil.TempFile("", "tls")
	if err != nil {
		return "", fmt.Errorf("failed to create file: %v", err)
	}
	block := &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: cert,
	}
	if err := pem.Encode(f, block); err != nil {
		// Best-effort cleanup; the encode error is the one worth reporting.
		f.Close()
		os.Remove(f.Name())
		return "", fmt.Errorf("failed to encode certificate: %v", err)
	}
	// Close explicitly and check the result: a failed close could mean the
	// PEM data never fully reached disk, and libtls would then read a
	// truncated CA file.
	if err := f.Close(); err != nil {
		os.Remove(f.Name())
		return "", fmt.Errorf("failed to close file: %v", err)
	}
	return f.Name(), nil
}
44
45func newTestServer() (*httptest.Server, *url.URL, string, error) {
46	ts := httptest.NewTLSServer(
47		http.HandlerFunc(
48			func(w http.ResponseWriter, r *http.Request) {
49				fmt.Fprintln(w, httpContent)
50			},
51		),
52	)
53
54	u, err := url.Parse(ts.URL)
55	if err != nil {
56		return nil, nil, "", fmt.Errorf("failed to parse URL %q: %v", ts.URL, err)
57	}
58
59	caFile, err := createCAFile(ts.TLS.Certificates[0].Certificate[0])
60	if err != nil {
61		return nil, nil, "", fmt.Errorf("failed to create CA file: %v", err)
62	}
63
64	return ts, u, caFile, nil
65}
66
67func TestTLSBasic(t *testing.T) {
68	ts, u, caFile, err := newTestServer()
69	if err != nil {
70		t.Fatalf("Failed to start test server: %v", err)
71	}
72	defer os.Remove(caFile)
73	defer ts.Close()
74
75	if err := Init(); err != nil {
76		t.Fatal(err)
77	}
78
79	cfg, err := NewConfig()
80	if err != nil {
81		t.Fatal(err)
82	}
83	defer cfg.Free()
84	cfg.SetCAFile(caFile)
85
86	tls, err := NewClient(cfg)
87	if err != nil {
88		t.Fatal(err)
89	}
90	defer tls.Free()
91
92	t.Logf("Connecting to %s", u.Host)
93
94	if err := tls.Connect(u.Host, ""); err != nil {
95		t.Fatal(err)
96	}
97	defer func() {
98		if err := tls.Close(); err != nil {
99			t.Fatalf("Close failed: %v", err)
100		}
101	}()
102
103	n, err := tls.Write([]byte("GET / HTTP/1.0\n\n"))
104	if err != nil {
105		t.Fatal(err)
106	}
107	t.Logf("Wrote %d bytes...", n)
108
109	buf := make([]byte, 1024)
110	n, err = tls.Read(buf)
111	if err != nil {
112		t.Fatal(err)
113	}
114	t.Logf("Read %d bytes...", n)
115
116	if !strings.Contains(string(buf), httpContent) {
117		t.Errorf("Response does not contain %q", httpContent)
118	}
119}
120
121func TestTLSSingleByteReadWrite(t *testing.T) {
122	ts, u, caFile, err := newTestServer()
123	if err != nil {
124		t.Fatalf("Failed to start test server: %v", err)
125	}
126	defer os.Remove(caFile)
127	defer ts.Close()
128
129	if err := Init(); err != nil {
130		t.Fatal(err)
131	}
132
133	cfg, err := NewConfig()
134	if err != nil {
135		t.Fatal(err)
136	}
137	defer cfg.Free()
138	cfg.SetCAFile(caFile)
139
140	tls, err := NewClient(cfg)
141	if err != nil {
142		t.Fatal(err)
143	}
144	defer tls.Free()
145
146	t.Logf("Connecting to %s", u.Host)
147
148	if err := tls.Connect(u.Host, ""); err != nil {
149		t.Fatal(err)
150	}
151	defer func() {
152		if err := tls.Close(); err != nil {
153			t.Fatalf("Close failed: %v", err)
154		}
155	}()
156
157	for _, b := range []byte("GET / HTTP/1.0\n\n") {
158		n, err := tls.Write([]byte{b})
159		if err != nil {
160			t.Fatal(err)
161		}
162		if n != 1 {
163			t.Fatalf("Wrote byte %v, got length %d, want 1", b, n)
164		}
165	}
166
167	var body []byte
168	for {
169		buf := make([]byte, 1)
170		n, err := tls.Read(buf)
171		if err != nil {
172			t.Fatal(err)
173		}
174		if n == 0 {
175			break
176		}
177		if n != 1 {
178			t.Fatalf("Read single byte, got length %d, want 1", n)
179		}
180		body = append(body, buf...)
181	}
182
183	if !strings.Contains(string(body), httpContent) {
184		t.Errorf("Response does not contain %q", httpContent)
185	}
186}
187
188func TestTLSInfo(t *testing.T) {
189	ts, u, caFile, err := newTestServer()
190	if err != nil {
191		t.Fatalf("Failed to start test server: %v", err)
192	}
193	defer os.Remove(caFile)
194	defer ts.Close()
195
196	if err := Init(); err != nil {
197		t.Fatal(err)
198	}
199
200	cfg, err := NewConfig()
201	if err != nil {
202		t.Fatal(err)
203	}
204	defer cfg.Free()
205	cfg.SetCAFile(caFile)
206
207	tls, err := NewClient(cfg)
208	if err != nil {
209		t.Fatal(err)
210	}
211	defer tls.Free()
212
213	t.Logf("Connecting to %s", u.Host)
214
215	if err := tls.Connect(u.Host, ""); err != nil {
216		t.Fatal(err)
217	}
218	defer func() {
219		if err := tls.Close(); err != nil {
220			t.Fatalf("Close failed: %v", err)
221		}
222	}()
223
224	// All of these should fail since the handshake has not completed.
225	if _, err := tls.ConnVersion(); err == nil {
226		t.Error("ConnVersion() return nil error, want error")
227	}
228	if _, err := tls.ConnCipher(); err == nil {
229		t.Error("ConnCipher() return nil error, want error")
230	}
231
232	if got, want := tls.PeerCertProvided(), false; got != want {
233		t.Errorf("PeerCertProvided() = %v, want %v", got, want)
234	}
235	for _, name := range []string{"127.0.0.1", "::1", "example.com"} {
236		if got, want := tls.PeerCertContainsName(name), false; got != want {
237			t.Errorf("PeerCertContainsName(%q) = %v, want %v", name, got, want)
238		}
239	}
240
241	if _, err := tls.PeerCertIssuer(); err == nil {
242		t.Error("PeerCertIssuer() returned nil error, want error")
243	}
244	if _, err := tls.PeerCertSubject(); err == nil {
245		t.Error("PeerCertSubject() returned nil error, want error")
246	}
247	if _, err := tls.PeerCertHash(); err == nil {
248		t.Error("PeerCertHash() returned nil error, want error")
249	}
250	if _, err := tls.PeerCertNotBefore(); err == nil {
251		t.Error("PeerCertNotBefore() returned nil error, want error")
252	}
253	if _, err := tls.PeerCertNotAfter(); err == nil {
254		t.Error("PeerCertNotAfter() returned nil error, want error")
255	}
256
257	// Complete the handshake...
258	if err := tls.Handshake(); err != nil {
259		t.Fatalf("Handshake failed: %v", err)
260	}
261
262	if version, err := tls.ConnVersion(); err != nil {
263		t.Errorf("ConnVersion() return error: %v", err)
264	} else {
265		t.Logf("Protocol version: %v", version)
266	}
267	if cipher, err := tls.ConnCipher(); err != nil {
268		t.Errorf("ConnCipher() return error: %v", err)
269	} else {
270		t.Logf("Cipher: %v", cipher)
271	}
272
273	if got, want := tls.PeerCertProvided(), true; got != want {
274		t.Errorf("PeerCertProvided() = %v, want %v", got, want)
275	}
276	for _, name := range []string{"127.0.0.1", "::1", "example.com"} {
277		if got, want := tls.PeerCertContainsName(name), true; got != want {
278			t.Errorf("PeerCertContainsName(%q) = %v, want %v", name, got, want)
279		}
280	}
281
282	if issuer, err := tls.PeerCertIssuer(); err != nil {
283		t.Errorf("PeerCertIssuer() returned error: %v", err)
284	} else {
285		t.Logf("Issuer: %v", issuer)
286	}
287	if subject, err := tls.PeerCertSubject(); err != nil {
288		t.Errorf("PeerCertSubject() returned error: %v", err)
289	} else {
290		t.Logf("Subject: %v", subject)
291	}
292	if hash, err := tls.PeerCertHash(); err != nil {
293		t.Errorf("PeerCertHash() returned error: %v", err)
294	} else if hash != certHash {
295		t.Errorf("Got cert hash %q, want %q", hash, certHash)
296	} else {
297		t.Logf("Hash: %v", hash)
298	}
299	if notBefore, err := tls.PeerCertNotBefore(); err != nil {
300		t.Errorf("PeerCertNotBefore() returned error: %v", err)
301	} else if !certNotBefore.Equal(notBefore) {
302		t.Errorf("Got cert notBefore %v, want %v", notBefore.UTC(), certNotBefore.UTC())
303	} else {
304		t.Logf("NotBefore: %v", notBefore.UTC())
305	}
306	if notAfter, err := tls.PeerCertNotAfter(); err != nil {
307		t.Errorf("PeerCertNotAfter() returned error: %v", err)
308	} else if !certNotAfter.Equal(notAfter) {
309		t.Errorf("Got cert notAfter %v, want %v", notAfter.UTC(), certNotAfter.UTC())
310	} else {
311		t.Logf("NotAfter: %v", notAfter.UTC())
312	}
313}
314