1package tls 2 3import ( 4 "encoding/pem" 5 "fmt" 6 "io/ioutil" 7 "net/http" 8 "net/http/httptest" 9 "net/url" 10 "os" 11 "strings" 12 "testing" 13 "time" 14) 15 16const ( 17 httpContent = "Hello, TLS!" 18 19 certHash = "SHA256:448f628a8a65aa18560e53a80c53acb38c51b427df0334082349141147dc9bf6" 20) 21 22var ( 23 certNotBefore = time.Unix(0, 0) 24 certNotAfter = certNotBefore.Add(1000000 * time.Hour) 25) 26 27// createCAFile writes a PEM encoded version of the certificate out to a 28// temporary file, for use by libtls. 29func createCAFile(cert []byte) (string, error) { 30 f, err := ioutil.TempFile("", "tls") 31 if err != nil { 32 return "", fmt.Errorf("failed to create file: %v", err) 33 } 34 defer f.Close() 35 block := &pem.Block{ 36 Type: "CERTIFICATE", 37 Bytes: cert, 38 } 39 if err := pem.Encode(f, block); err != nil { 40 return "", fmt.Errorf("failed to encode certificate: %v", err) 41 } 42 return f.Name(), nil 43} 44 45func newTestServer() (*httptest.Server, *url.URL, string, error) { 46 ts := httptest.NewTLSServer( 47 http.HandlerFunc( 48 func(w http.ResponseWriter, r *http.Request) { 49 fmt.Fprintln(w, httpContent) 50 }, 51 ), 52 ) 53 54 u, err := url.Parse(ts.URL) 55 if err != nil { 56 return nil, nil, "", fmt.Errorf("failed to parse URL %q: %v", ts.URL, err) 57 } 58 59 caFile, err := createCAFile(ts.TLS.Certificates[0].Certificate[0]) 60 if err != nil { 61 return nil, nil, "", fmt.Errorf("failed to create CA file: %v", err) 62 } 63 64 return ts, u, caFile, nil 65} 66 67func TestTLSBasic(t *testing.T) { 68 ts, u, caFile, err := newTestServer() 69 if err != nil { 70 t.Fatalf("Failed to start test server: %v", err) 71 } 72 defer os.Remove(caFile) 73 defer ts.Close() 74 75 if err := Init(); err != nil { 76 t.Fatal(err) 77 } 78 79 cfg, err := NewConfig() 80 if err != nil { 81 t.Fatal(err) 82 } 83 defer cfg.Free() 84 cfg.SetCAFile(caFile) 85 86 tls, err := NewClient(cfg) 87 if err != nil { 88 t.Fatal(err) 89 } 90 defer tls.Free() 91 92 t.Logf("Connecting to %s", u.Host) 93 94 if err := tls.Connect(u.Host, ""); err != nil { 95 t.Fatal(err) 96 } 97 defer func() { 98 if err := tls.Close(); err != nil { 99 t.Fatalf("Close failed: %v", err) 100 } 101 }() 102 103 n, err := tls.Write([]byte("GET / HTTP/1.0\n\n")) 104 if err != nil { 105 t.Fatal(err) 106 } 107 t.Logf("Wrote %d bytes...", n) 108 109 buf := make([]byte, 1024) 110 n, err = tls.Read(buf) 111 if err != nil { 112 t.Fatal(err) 113 } 114 t.Logf("Read %d bytes...", n) 115 116 if !strings.Contains(string(buf), httpContent) { 117 t.Errorf("Response does not contain %q", httpContent) 118 } 119} 120 121func TestTLSSingleByteReadWrite(t *testing.T) { 122 ts, u, caFile, err := newTestServer() 123 if err != nil { 124 t.Fatalf("Failed to start test server: %v", err) 125 } 126 defer os.Remove(caFile) 127 defer ts.Close() 128 129 if err := Init(); err != nil { 130 t.Fatal(err) 131 } 132 133 cfg, err := NewConfig() 134 if err != nil { 135 t.Fatal(err) 136 } 137 defer cfg.Free() 138 cfg.SetCAFile(caFile) 139 140 tls, err := NewClient(cfg) 141 if err != nil { 142 t.Fatal(err) 143 } 144 defer tls.Free() 145 146 t.Logf("Connecting to %s", u.Host) 147 148 if err := tls.Connect(u.Host, ""); err != nil { 149 t.Fatal(err) 150 } 151 defer func() { 152 if err := tls.Close(); err != nil { 153 t.Fatalf("Close failed: %v", err) 154 } 155 }() 156 157 for _, b := range []byte("GET / HTTP/1.0\n\n") { 158 n, err := tls.Write([]byte{b}) 159 if err != nil { 160 t.Fatal(err) 161 } 162 if n != 1 { 163 t.Fatalf("Wrote byte %v, got length %d, want 1", b, n) 164 } 165 } 166 167 var body []byte 168 for { 169 buf := make([]byte, 1) 170 n, err := tls.Read(buf) 171 if err != nil { 172 t.Fatal(err) 173 } 174 if n == 0 { 175 break 176 } 177 if n != 1 { 178 t.Fatalf("Read single byte, got length %d, want 1", n) 179 } 180 body = append(body, buf...) 181 } 182 183 if !strings.Contains(string(body), httpContent) { 184 t.Errorf("Response does not contain %q", httpContent) 185 } 186} 187 188func TestTLSInfo(t *testing.T) { 189 ts, u, caFile, err := newTestServer() 190 if err != nil { 191 t.Fatalf("Failed to start test server: %v", err) 192 } 193 defer os.Remove(caFile) 194 defer ts.Close() 195 196 if err := Init(); err != nil { 197 t.Fatal(err) 198 } 199 200 cfg, err := NewConfig() 201 if err != nil { 202 t.Fatal(err) 203 } 204 defer cfg.Free() 205 cfg.SetCAFile(caFile) 206 207 tls, err := NewClient(cfg) 208 if err != nil { 209 t.Fatal(err) 210 } 211 defer tls.Free() 212 213 t.Logf("Connecting to %s", u.Host) 214 215 if err := tls.Connect(u.Host, ""); err != nil { 216 t.Fatal(err) 217 } 218 defer func() { 219 if err := tls.Close(); err != nil { 220 t.Fatalf("Close failed: %v", err) 221 } 222 }() 223 224 // All of these should fail since the handshake has not completed. 225 if _, err := tls.ConnVersion(); err == nil { 226 t.Error("ConnVersion() return nil error, want error") 227 } 228 if _, err := tls.ConnCipher(); err == nil { 229 t.Error("ConnCipher() return nil error, want error") 230 } 231 232 if got, want := tls.PeerCertProvided(), false; got != want { 233 t.Errorf("PeerCertProvided() = %v, want %v", got, want) 234 } 235 for _, name := range []string{"127.0.0.1", "::1", "example.com"} { 236 if got, want := tls.PeerCertContainsName(name), false; got != want { 237 t.Errorf("PeerCertContainsName(%q) = %v, want %v", name, got, want) 238 } 239 } 240 241 if _, err := tls.PeerCertIssuer(); err == nil { 242 t.Error("PeerCertIssuer() returned nil error, want error") 243 } 244 if _, err := tls.PeerCertSubject(); err == nil { 245 t.Error("PeerCertSubject() returned nil error, want error") 246 } 247 if _, err := tls.PeerCertHash(); err == nil { 248 t.Error("PeerCertHash() returned nil error, want error") 249 } 250 if _, err := tls.PeerCertNotBefore(); err == nil { 251 t.Error("PeerCertNotBefore() returned nil error, want error") 252 } 253 if _, err := tls.PeerCertNotAfter(); err == nil { 254 t.Error("PeerCertNotAfter() returned nil error, want error") 255 } 256 257 // Complete the handshake... 258 if err := tls.Handshake(); err != nil { 259 t.Fatalf("Handshake failed: %v", err) 260 } 261 262 if version, err := tls.ConnVersion(); err != nil { 263 t.Errorf("ConnVersion() return error: %v", err) 264 } else { 265 t.Logf("Protocol version: %v", version) 266 } 267 if cipher, err := tls.ConnCipher(); err != nil { 268 t.Errorf("ConnCipher() return error: %v", err) 269 } else { 270 t.Logf("Cipher: %v", cipher) 271 } 272 273 if got, want := tls.PeerCertProvided(), true; got != want { 274 t.Errorf("PeerCertProvided() = %v, want %v", got, want) 275 } 276 for _, name := range []string{"127.0.0.1", "::1", "example.com"} { 277 if got, want := tls.PeerCertContainsName(name), true; got != want { 278 t.Errorf("PeerCertContainsName(%q) = %v, want %v", name, got, want) 279 } 280 } 281 282 if issuer, err := tls.PeerCertIssuer(); err != nil { 283 t.Errorf("PeerCertIssuer() returned error: %v", err) 284 } else { 285 t.Logf("Issuer: %v", issuer) 286 } 287 if subject, err := tls.PeerCertSubject(); err != nil { 288 t.Errorf("PeerCertSubject() returned error: %v", err) 289 } else { 290 t.Logf("Subject: %v", subject) 291 } 292 if hash, err := tls.PeerCertHash(); err != nil { 293 t.Errorf("PeerCertHash() returned error: %v", err) 294 } else if hash != certHash { 295 t.Errorf("Got cert hash %q, want %q", hash, certHash) 296 } else { 297 t.Logf("Hash: %v", hash) 298 } 299 if notBefore, err := tls.PeerCertNotBefore(); err != nil { 300 t.Errorf("PeerCertNotBefore() returned error: %v", err) 301 } else if !certNotBefore.Equal(notBefore) { 302 t.Errorf("Got cert notBefore %v, want %v", notBefore.UTC(), certNotBefore.UTC()) 303 } else { 304 t.Logf("NotBefore: %v", notBefore.UTC()) 305 } 306 if notAfter, err := tls.PeerCertNotAfter(); err != nil { 307 t.Errorf("PeerCertNotAfter() returned error: %v", err) 308 } else if !certNotAfter.Equal(notAfter) { 309 t.Errorf("Got cert notAfter %v, want %v", notAfter.UTC(), certNotAfter.UTC()) 310 } else { 311 t.Logf("NotAfter: %v", notAfter.UTC()) 312 } 313} 314