Mocking AWS S3 Scanner with Golang

By Effi Bar-She’an

Analyzing files in an AWS S3 bucket is a common task, and plenty of examples for doing so are available on the Internet. Doing it in a way that enables unit tests, however, is somewhat of a mystery.

So here’s a complete example of a golang client that:

  1. Downloads each file from an S3 bucket to a local filesystem
  2. Does some work
  3. Deletes the temporary file from your filesystem

And, it is testable!

Let’s take a look at what the client looks like:

tufin.NewS3Scanner("us-east-2", ".").Scan("my-bucket", func(file *os.File) {
	log.Info(file.Name())
})

The constructor accepts an AWS region and a local path to use for storing the temporary downloaded S3 files. Scan accepts the bucket to scan and triggers the callback for each downloaded file.

S3Scanner Implementation

We start by creating an S3 service in the S3Scanner constructor:

func NewS3Scanner(region string, dir string) *S3Scanner {

	// verify aws auth
	verifyEnv("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")

	awsSession, err := session.NewSession(&aws.Config{
		Region: aws.String(region),
	})
	if err != nil {
		log.Fatalf("failed to create aws session to region '%s' with '%v'", region, err)
	}

	s3Svc := s3.New(awsSession)

	return &S3Scanner{
		svc:              s3Svc,
		downloader:       s3manager.NewDownloaderWithClient(s3Svc),
		tmpDirFilesystem: dir,
	}
}

The trick here is to use the s3iface.S3API interface instead of *s3.S3, which is returned by the s3.New function. This makes it testable!

type S3Scanner struct {
	svc              s3iface.S3API
	downloader       *s3manager.Downloader
	tmpDirFilesystem string
}

Note that when you initialize a new service client (svc above) without supplying explicit credentials, the AWS SDK attempts to find AWS credentials by using the default credential provider chain. In our case verifyEnv verifies that AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY have non-empty values:

func verifyEnv(keys ...string) {

	for _, currKey := range keys {
		if val := os.Getenv(currKey); val == "" {
			// report the name of the missing variable, not its (empty) value
			log.Fatalf("Please, set '%s'", currKey)
		}
	}
}
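
Because verifyEnv exits the process through log.Fatalf, it is awkward to cover with a unit test. Here is a minimal sketch of one alternative, assuming a hypothetical helper named missingKeys (not part of the original code) that receives the lookup function as a parameter and returns the missing names instead of exiting:

```go
package main

import (
	"fmt"
	"os"
)

// missingKeys is a hypothetical helper: it returns the names of the keys
// for which lookup yields an empty value. In production lookup would be
// os.Getenv; in a test it can be backed by a map.
func missingKeys(lookup func(string) string, keys ...string) []string {
	var missing []string
	for _, key := range keys {
		if lookup(key) == "" {
			missing = append(missing, key)
		}
	}
	return missing
}

func main() {
	missing := missingKeys(os.Getenv, "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
	if len(missing) > 0 {
		fmt.Printf("please set: %v\n", missing)
		return
	}
	fmt.Println("AWS credentials found in the environment")
}
```

The caller decides what to do with the result (log, exit or return an error), which keeps the check itself free of side effects.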

Now, we call ListObjects using the svc that was created in the constructor, iterate the response and pass the files on to our callback.

func (fi S3Scanner) Scan(bucket string, callback func(*os.File)) {

	response, err := fi.svc.ListObjects(&s3.ListObjectsInput{
		Bucket: aws.String(bucket),
	})
	if err != nil {
		log.Fatalf("failed to get list of s3 files from bucket '%s' with '%v'", bucket, err)
	}

	for _, currObj := range response.Contents {
		currS3FilePath := *currObj.Key
		currS3FileName := currS3FilePath[strings.LastIndex(currS3FilePath, "/")+1:]
		currFilePath := fmt.Sprintf("%s/%s", fi.tmpDirFilesystem, currS3FileName)
		callback(download(fi.downloader, bucket, currObj.Key, currS3FilePath, currFilePath))
		deleteFromFilesystem(currFilePath)
	}
}
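
The file-name handling inside Scan can be looked at in isolation. The sketch below is an alternative, not what the original code does: it swaps the manual strings.LastIndex slicing for path.Base and filepath.Join from the standard library. The helper names localPath and isFolderPlaceholder are hypothetical.

```go
package main

import (
	"fmt"
	"path"
	"path/filepath"
	"strings"
)

// isFolderPlaceholder reports whether an S3 key ends with "/". Such keys
// have no file name of their own, so a scanner may want to skip them.
func isFolderPlaceholder(key string) bool {
	return strings.HasSuffix(key, "/")
}

// localPath maps an S3 object key to a file path directly under dir.
// S3 keys always use "/", so path.Base extracts the last element, while
// filepath.Join builds the local path with the OS-specific separator.
func localPath(dir, key string) string {
	return filepath.Join(dir, path.Base(key))
}

func main() {
	keys := []string{"access-logs/10Jul2020/1.log", "access-logs/10Jul2020/"}
	for _, key := range keys {
		if isFolderPlaceholder(key) {
			fmt.Printf("skipping %q\n", key)
			continue
		}
		fmt.Println(localPath(".", key))
	}
}
```

One difference worth noting: filepath.Join cleans its result, so with "." as the directory the local path is the bare file name rather than a "./"-prefixed one.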

This follows the AWS Go SDK example, with our trick of using the s3iface.S3API interface to make it testable :)

Let’s take a look at the download function which uses the *s3manager.Downloader that we created in the constructor:

func download(downloader *s3manager.Downloader, bucket string, key *string, s3FilePath string, path string) *os.File {

	log.Debugf("Creating a file '%s'...", path)
	ret, err := os.Create(path)
	if err != nil {
		// include the underlying error so the failure can be diagnosed
		log.Fatalf("failed to create file '%s' with '%v'", path, err)
	}

	log.Infof("Downloading file '%s' from s3...", s3FilePath)
	_, err = downloader.Download(ret, &s3.GetObjectInput{
		Bucket: aws.String(bucket),
		Key:    key,
	})
	if err != nil {
		log.Fatalf("unable to download '%s' from bucket '%s' with '%v'", s3FilePath, bucket, err)
	}

	return ret
}
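
The create, download, callback, delete lifecycle can also be sketched with only the standard library. This is an illustration under assumptions, not the article's implementation: processTemp and runDemo are hypothetical names, fetch stands in for downloader.Download (which writes through an io.WriterAt), and errors are returned to the caller instead of ending the process with log.Fatalf.

```go
package main

import (
	"fmt"
	"io"
	"os"
	"path/filepath"
)

// processTemp creates dir/name, fills it via fetch, hands the file to
// callback, and always closes and removes the file afterwards.
func processTemp(dir, name string, fetch func(io.WriterAt) error, callback func(*os.File) error) (err error) {
	path := filepath.Join(dir, name)
	file, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create %q: %w", path, err)
	}
	// Clean up even when fetch or callback fails.
	defer func() {
		file.Close()
		if rmErr := os.Remove(path); err == nil && rmErr != nil {
			err = rmErr
		}
	}()
	if err = fetch(file); err != nil {
		return fmt.Errorf("download %q: %w", name, err)
	}
	return callback(file)
}

// runDemo runs processTemp with a fake fetch and reports what the callback
// read and whether the temporary file was removed afterwards.
func runDemo() (content string, removed bool, err error) {
	dir, err := os.MkdirTemp("", "s3-scanner")
	if err != nil {
		return "", false, err
	}
	defer os.RemoveAll(dir)

	err = processTemp(dir, "1.log",
		func(w io.WriterAt) error {
			// WriteAt does not move the file offset, so the callback
			// can read from the start without seeking.
			_, werr := w.WriteAt([]byte("This is S3 file mock data"), 0)
			return werr
		},
		func(f *os.File) error {
			data, rerr := io.ReadAll(f)
			content = string(data)
			return rerr
		})
	_, statErr := os.Stat(filepath.Join(dir, "1.log"))
	return content, os.IsNotExist(statErr), err
}

func main() {
	content, removed, err := runDemo()
	fmt.Println(content, removed, err)
}
```

Using defer for the cleanup means the temporary file is removed even if the callback returns an error.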

The Unit Tests

As mentioned above, our code is testable. Let’s see how.

We start by mocking S3 ListObjects:

type mockS3Client struct {
	s3iface.S3API
}

func (m *mockS3Client) ListObjects(_ *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {

	return &s3.ListObjectsOutput{
		Contents: []*s3.Object{{Key: aws.String(key)}},
	}, nil
}

As you can see, our mock embeds the s3iface.S3API interface, so it satisfies the interface while overriding only ListObjects, which returns an object with a key that represents the file path on S3.
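
Embedding an interface in a struct is what lets the mock define only the methods a test needs. A minimal sketch of the mechanism, assuming a hypothetical two-method interface storageAPI that stands in for the much larger s3iface.S3API:

```go
package main

import "fmt"

// storageAPI is a hypothetical stand-in for a large SDK interface.
type storageAPI interface {
	List(bucket string) ([]string, error)
	Delete(bucket, key string) error
}

// mockStorage embeds the interface, so it satisfies storageAPI without
// implementing every method itself. The embedded value is left nil.
type mockStorage struct {
	storageAPI
}

// List overrides the embedded method with canned data.
func (m *mockStorage) List(bucket string) ([]string, error) {
	return []string{"access-logs/10Jul2020/1.log"}, nil
}

// callDelete reports whether calling Delete, which the mock does not
// override, panicked.
func callDelete(api storageAPI) (panicked bool) {
	defer func() {
		if r := recover(); r != nil {
			panicked = true
		}
	}()
	_ = api.Delete("my-bucket", "some-key")
	return false
}

func main() {
	var api storageAPI = &mockStorage{}
	keys, err := api.List("my-bucket")
	fmt.Println(keys, err)
	fmt.Println("Delete panicked:", callDelete(api))
}
```

The trade-off is that any method the mock does not override is forwarded to the nil embedded interface and panics at run time, so a test fails loudly if the code under test starts calling something the mock does not cover.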

To mock the Download function of s3manager, we clear the client's Send handlers and register our own with svc.Handlers.Send.PushBack. Instead of making a network call, this handler returns mock S3 file data:

func getDownloader() *s3manager.Downloader {

	var locker sync.Mutex
	payload := []byte(data)

	svc := s3.New(unit.Session)
	svc.Handlers.Send.Clear()
	svc.Handlers.Send.PushBack(func(r *request.Request) {
		locker.Lock()
		defer locker.Unlock()

		r.HTTPResponse = &http.Response{
			StatusCode: http.StatusOK,
			Body:       ioutil.NopCloser(bytes.NewReader(payload)),
			Header:     http.Header{},
		}
		r.HTTPResponse.Header.Set("Content-Length", "1")
	})

	return s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) {
		d.Concurrency = 1
		d.PartSize = 1
	})
}
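
The idea behind the handler, replacing the step that talks to the network with one that fabricates a response, can be shown with only the standard library. The sketch below is an analogue built on http.RoundTripper, not the AWS SDK handler API itself; roundTripFunc, cannedClient and fetch are hypothetical names, and the URL is a placeholder that is never contacted.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

// roundTripFunc adapts a plain function to the http.RoundTripper interface.
type roundTripFunc func(*http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return f(r)
}

// cannedClient returns an HTTP client whose transport never touches the
// network: every request is answered with a 200 response carrying payload.
func cannedClient(payload string) *http.Client {
	return &http.Client{
		Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
			return &http.Response{
				StatusCode: http.StatusOK,
				Header:     http.Header{},
				Body:       io.NopCloser(strings.NewReader(payload)),
				Request:    r,
			}, nil
		}),
	}
}

// fetch is the code under test: it only knows about *http.Client.
func fetch(client *http.Client, url string) (string, error) {
	resp, err := client.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	return string(body), nil
}

func main() {
	client := cannedClient("This is S3 file mock data")
	body, err := fetch(client, "https://example.invalid/my-bucket/1.log")
	fmt.Println(body, err)
}
```

In both cases the code under test keeps using its normal client type, and only the lowest layer is swapped out.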

Let’s put it all together:

package tufin

import (
	"bufio"
	"bytes"
	"io/ioutil"
	"net/http"
	"os"
	"strings"
	"sync"
	"testing"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/awstesting/unit"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/stretchr/testify/require"
)

const (
	key  = "access-logs/10Jul2020/1.log"
	data = "This is S3 file mock data"
)

type mockS3Client struct {
	s3iface.S3API
}

func (m *mockS3Client) ListObjects(_ *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {

	return &s3.ListObjectsOutput{
		Contents: []*s3.Object{{Key: aws.String(key)}},
	}, nil
}

func TestS3Scanner_Scan(t *testing.T) {

	S3Scanner{
		svc:              &mockS3Client{},
		downloader:       getDownloader(),
		tmpDirFilesystem: ".",
	}.Scan("my-bucket", func(file *os.File) {
		defer func() { require.NoError(t, file.Close()) }()
		path := file.Name()
		require.Equal(t, key[strings.LastIndex(key, "/")+1:],
			path[strings.LastIndex(path, "/")+1:])
		scanner := bufio.NewScanner(file)
		scanner.Scan()
		require.Equal(t, data, scanner.Text())
	})
}

func getDownloader() *s3manager.Downloader {

	var locker sync.Mutex
	payload := []byte(data)

	svc := s3.New(unit.Session)
	svc.Handlers.Send.Clear()
	svc.Handlers.Send.PushBack(func(r *request.Request) {
		locker.Lock()
		defer locker.Unlock()

		r.HTTPResponse = &http.Response{
			StatusCode: http.StatusOK,
			Body:       ioutil.NopCloser(bytes.NewReader(payload)),
			Header:     http.Header{},
		}
		r.HTTPResponse.Header.Set("Content-Length", "1")
	})

	return s3manager.NewDownloaderWithClient(svc, func(d *s3manager.Downloader) {
		d.Concurrency = 1
		d.PartSize = 1
	})
}

See a full example in: https://github.com/Tufin/blog/tree/master/s3-scanner

Reference: Go SDK S3 Docs, AWS Go SDK Examples on GitHub

From the Security Policy Company. This blog is dedicated to cloud-native topics such as Kubernetes, cloud security and micro-services.
