Mocking AWS S3 Scanner with Golang

// Scan every object in my-bucket (region us-east-2), downloading each one into
// the current directory and logging its local file name.
tufin.NewS3Scanner("us-east-2", ".").Scan("my-bucket", func(file *os.File) {
	log.Info(file.Name())
})

S3Scanner Implementation

// S3Scanner downloads the objects of an S3 bucket one at a time into a local
// directory and hands each downloaded file to a caller-supplied callback.
type S3Scanner struct {
	svc              s3iface.S3API         // lists the bucket contents
	downloader       *s3manager.Downloader // fetches each object
	tmpDirFilesystem string                // local directory objects are written into
}

// NewS3Scanner returns an S3Scanner bound to the given AWS region that stores
// downloaded objects under dir.
//
// It calls log.Fatalf when AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY is unset
// or empty, or when the AWS session cannot be created.
func NewS3Scanner(region string, dir string) *S3Scanner {
	// Fail fast if the credentials are not present in the environment.
	verifyEnv("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")

	cfg := &aws.Config{Region: aws.String(region)}
	sess, err := session.NewSession(cfg)
	if err != nil {
		log.Fatalf("failed to create aws session to region '%s' with '%v'", region, err)
	}

	client := s3.New(sess)
	scanner := &S3Scanner{
		svc:              client,
		downloader:       s3manager.NewDownloaderWithClient(client),
		tmpDirFilesystem: dir,
	}
	return scanner
}
// verifyEnv calls log.Fatalf unless every named environment variable is set to
// a non-empty value. All missing names are reported in a single message so
// they can be fixed in one go.
func verifyEnv(keys ...string) {
	if missing := missingEnvKeys(keys...); len(missing) > 0 {
		// Report the variable names; their values are empty by definition here.
		log.Fatalf("Please, set '%s'", strings.Join(missing, "', '"))
	}
}

// missingEnvKeys returns, in input order, the keys whose environment value is
// unset or empty.
func missingEnvKeys(keys ...string) []string {
	var missing []string
	for _, k := range keys {
		if os.Getenv(k) == "" {
			missing = append(missing, k)
		}
	}
	return missing
}
// Scan lists every object in bucket, downloads each one into fi.tmpDirFilesystem
// (named after the last "/"-separated segment of its key), passes the open file
// to callback, then closes and deletes the local copy. The callback must not
// retain the file after it returns.
//
// Listing follows pagination, because a single ListObjects call returns at most
// one page of keys. Keys with an empty final segment (e.g. "folder/" placeholders)
// are skipped, since there is no file name to create. Listing and download
// failures call log.Fatalf.
func (fi S3Scanner) Scan(bucket string, callback func(*os.File)) {
	input := &s3.ListObjectsInput{
		Bucket: aws.String(bucket),
	}

	for {
		response, err := fi.svc.ListObjects(input)
		if err != nil {
			log.Fatalf("failed to get list of s3 files from bucket '%s' with '%v'", bucket, err)
		}

		for _, currObj := range response.Contents {
			currS3FilePath := aws.StringValue(currObj.Key)
			currS3FileName := currS3FilePath[strings.LastIndex(currS3FilePath, "/")+1:]
			if currS3FileName == "" {
				// Folder placeholder or empty key: nothing to download.
				continue
			}
			currFilePath := fmt.Sprintf("%s/%s", fi.tmpDirFilesystem, currS3FileName)

			file := download(fi.downloader, bucket, currObj.Key, currS3FilePath, currFilePath)
			callback(file)
			// Release the descriptor even if the callback did not. If the callback
			// already closed it, this second Close only returns an error we ignore.
			_ = file.Close()
			deleteFromFilesystem(currFilePath)
		}

		if !aws.BoolValue(response.IsTruncated) || len(response.Contents) == 0 {
			return
		}
		// ListObjects (v1) only sets NextMarker when a Delimiter is supplied,
		// so fall back to the last key of this page.
		if response.NextMarker != nil {
			input.Marker = response.NextMarker
		} else {
			input.Marker = response.Contents[len(response.Contents)-1].Key
		}
	}
}
// download creates (or truncates) the local file at path, fills it with the S3
// object identified by bucket and key using downloader, and returns the still-open
// file. The caller owns the returned file and is responsible for closing it.
//
// s3FilePath is used only in the failure message. Any failure calls log.Fatalf.
// NOTE(review): callers read the file from the start without seeking. That
// presumably relies on the downloader writing via WriteAt, which leaves the
// file offset at 0 — confirm.
func download(downloader *s3manager.Downloader, bucket string, key *string, s3FilePath string, path string) *os.File {
	log.Debugf("Creating a file '%s'...", path)
	ret, err := os.Create(path)
	if err != nil {
		// Include the cause (permissions, missing directory, ...) in the message.
		log.Fatalf("failed to create file '%s' with '%v'", path, err)
	}

	log.Infof("Downloading file '%s' from s3...", path)
	_, err = downloader.Download(ret, &s3.GetObjectInput{
		Bucket: aws.String(bucket),
		Key:    key,
	})
	if err != nil {
		log.Fatalf("unable to download '%s' from bucket '%s' with '%v'", s3FilePath, bucket, err)
	}

	return ret
}

The Unit Tests

// mockS3Client satisfies s3iface.S3API by embedding the interface (left nil) and
// overriding only ListObjects. Any other S3 method called on it will panic.
type mockS3Client struct {
	s3iface.S3API
}

// ListObjects ignores its input and reports a listing that contains exactly one
// object, whose key is the package-level key constant.
func (c *mockS3Client) ListObjects(*s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
	obj := &s3.Object{Key: aws.String(key)}
	out := &s3.ListObjectsOutput{Contents: []*s3.Object{obj}}
	return out, nil
}
// getDownloader builds an s3manager.Downloader whose S3 client never reaches the
// network. The Send handler chain is replaced by a stub that answers every
// request with HTTP 200 and the bytes of data.
//
// Concurrency and PartSize are pinned to 1, and the stub advertises
// Content-Length "1". NOTE(review): this presumably makes the downloader treat
// the object as a single part and issue one GetObject, whose whole body is then
// written out — confirm against the s3manager version in use.
func getDownloader() *s3manager.Downloader {
	payload := []byte(data)
	var mu sync.Mutex

	client := s3.New(unit.Session)
	client.Handlers.Send.Clear()
	client.Handlers.Send.PushBack(func(req *request.Request) {
		mu.Lock()
		defer mu.Unlock()

		header := http.Header{}
		header.Set("Content-Length", "1")
		req.HTTPResponse = &http.Response{
			StatusCode: http.StatusOK,
			Body:       ioutil.NopCloser(bytes.NewReader(payload)),
			Header:     header,
		}
	})

	configure := func(d *s3manager.Downloader) {
		d.Concurrency = 1
		d.PartSize = 1
	}
	return s3manager.NewDownloaderWithClient(client, configure)
}
package tufin

import (
"bufio"
"bytes"
"io/ioutil"
"net/http"
"os"
"strings"
"sync"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/stretchr/testify/require"
)

// key is the single object key the mocked ListObjects reports. Scan names the
// local file after its last "/"-separated segment ("1.log").
const key = "access-logs/10Jul2020/1.log"

// data is the body served by the stubbed Send handler in getDownloader.
const data = "This is S3 file mock data"

// mockS3Client is a test double for s3iface.S3API. It embeds the (nil) interface
// so that it satisfies the full API, and implements ListObjects only; calling
// any other method panics.
type mockS3Client struct {
	s3iface.S3API
}

// ListObjects disregards the request and returns a single-object listing keyed
// by the package-level key constant. It never fails.
func (c *mockS3Client) ListObjects(*s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
	listing := []*s3.Object{
		{Key: aws.String(key)},
	}
	return &s3.ListObjectsOutput{Contents: listing}, nil
}

// TestS3Scanner_Scan runs Scan against a mocked listing (one object named key)
// and a stubbed downloader (body data). It checks that the callback receives a
// file named after the key's last path segment whose first line is data, and
// that the callback runs exactly once.
func TestS3Scanner_Scan(t *testing.T) {
	calls := 0

	S3Scanner{
		svc:              &mockS3Client{},
		downloader:       getDownloader(),
		tmpDirFilesystem: ".",
	}.Scan("my-bucket", func(file *os.File) {
		calls++
		defer func() { require.NoError(t, file.Close()) }()

		path := file.Name()
		require.Equal(t, key[strings.LastIndex(key, "/")+1:],
			path[strings.LastIndex(path, "/")+1:])

		scanner := bufio.NewScanner(file)
		require.True(t, scanner.Scan(), "expected a line in the downloaded file, scan error: %v", scanner.Err())
		require.Equal(t, data, scanner.Text())
	})

	// The assertions inside the callback only run if Scan invokes it; without
	// this check the test would pass vacuously on an empty listing.
	require.Equal(t, 1, calls, "callback should run once per listed object")
}

// getDownloader returns an s3manager.Downloader backed by an offline S3 client.
// Its Send handlers are cleared and replaced with a stub that responds to every
// request with status 200 and a body holding the bytes of data.
//
// The downloader runs with Concurrency 1 and PartSize 1, and the stub sets
// Content-Length to "1". NOTE(review): presumably this keeps the downloader to a
// single GetObject whose full body gets written — verify against the SDK version.
func getDownloader() *s3manager.Downloader {
	body := []byte(data)
	var guard sync.Mutex

	stubbed := s3.New(unit.Session)
	stubbed.Handlers.Send.Clear()

	respond := func(req *request.Request) {
		guard.Lock()
		defer guard.Unlock()

		hdr := http.Header{}
		hdr.Set("Content-Length", "1")
		req.HTTPResponse = &http.Response{
			StatusCode: http.StatusOK,
			Body:       ioutil.NopCloser(bytes.NewReader(body)),
			Header:     hdr,
		}
	}
	stubbed.Handlers.Send.PushBack(respond)

	return s3manager.NewDownloaderWithClient(stubbed, func(dl *s3manager.Downloader) {
		dl.Concurrency = 1
		dl.PartSize = 1
	})
}

--

--

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store
Tufin

Tufin

From the Security Policy Company. This blog is dedicated to cloud-native topics such as Kubernetes, cloud security and microservices.