You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

108 lines
3.6 KiB

import os
import pytest
import random
import dask.bag as db
# Skip every test in this module when the optional fastavro package is absent.
fastavro = pytest.importorskip('fastavro')

# Reference data shared by the tests: 1000 records with a random name drawn
# from four choices and a random integer in [0, 100].
expected = [{'name': random.choice(['fred', 'wilma', 'barney', 'betty']),
             'number': random.randint(0, 100)} for _ in range(1000)]

# Avro record schema with one field per key of the ``expected`` records.
schema = {
    'doc': 'Descr',
    'name': 'Random',
    'namespace': 'test',
    'type': 'record',
    'fields': [
        {'name': 'name', 'type': 'string'},
        {'name': 'number', 'type': 'int'},
    ],
}
def test_onefile_oneblock(tmpdir):
    """Write all records to one avro file and read them back unsplit.

    With ``blocksize=None`` the test expects a single partition whose
    contents equal the records that were written.
    """
    path = os.path.join(tmpdir, 'one.avro')
    with open(path, 'wb') as sink:
        fastavro.writer(sink, records=expected, schema=schema)

    bag = db.read_avro(path, blocksize=None)
    assert bag.npartitions == 1
    assert bag.compute() == expected
def test_twofile_oneblock(tmpdir):
    """Split the records across two avro files and read both via a glob.

    With ``blocksize=None`` the test expects one partition per file and the
    concatenated contents to equal the full reference data.
    """
    # First half goes to one.avro, second half to two.avro, in that order.
    halves = (
        ('one.avro', expected[:500]),
        ('two.avro', expected[500:]),
    )
    for filename, records in halves:
        with open(os.path.join(tmpdir, filename), 'wb') as sink:
            fastavro.writer(sink, records=records, schema=schema)

    pattern = os.path.join(tmpdir, '*.avro')
    bag = db.read_avro(pattern, blocksize=None)
    assert bag.npartitions == 2
    assert bag.compute() == expected
def test_twofile_multiblock(tmpdir):
    """Read two multi-block avro files, first whole and then in small blocks.

    The files are read once with ``blocksize=None`` (expecting one partition
    per file) and once with ``blocksize=1000`` (expecting more than two
    partitions); in both cases the records must equal the reference data.
    """
    fn1 = os.path.join(tmpdir, 'one.avro')
    fn2 = os.path.join(tmpdir, 'two.avro')
    # A small sync_interval makes the writer start new blocks frequently, so
    # each file holds several blocks for the reader to split on.
    # NOTE(review): the keyword value was truncated in the source; 100 is a
    # reconstruction -- confirm against the intended block layout.
    with open(fn1, 'wb') as f:
        fastavro.writer(f, records=expected[:500], schema=schema,
                        sync_interval=100)
    with open(fn2, 'wb') as f:
        fastavro.writer(f, records=expected[500:], schema=schema,
                        sync_interval=100)

    b = db.read_avro(os.path.join(tmpdir, '*.avro'), blocksize=None)
    assert b.npartitions == 2
    assert b.compute() == expected

    b = db.read_avro(os.path.join(tmpdir, '*.avro'), blocksize=1000)
    assert b.npartitions > 2
    assert b.compute() == expected
def test_roundtrip_simple(tmpdir):
    """Write a tiny bag to avro (lazily, then eagerly) and read it back."""
    from dask.delayed import Delayed

    pattern = os.path.join(tmpdir, 'out*.avro')
    source = db.from_sequence([{'a': i} for i in [1, 2, 3, 4, 5]],
                              npartitions=2)
    record_schema = {
        'name': 'Test',
        'type': 'record',
        'fields': [
            {'name': 'a', 'type': 'int'},
        ],
    }

    # With compute=False the call is expected to hand back Delayed objects.
    lazy = source.to_avro(pattern, record_schema, compute=False)
    assert isinstance(lazy[0], Delayed)

    # The eager call is expected to return one entry per partition.
    written = source.to_avro(pattern, record_schema)
    assert len(written) == 2

    restored = db.read_avro(pattern)
    assert source.compute() == restored.compute()
@pytest.mark.parametrize('codec', ['null', 'deflate', 'snappy'])
def test_roundtrip(tmpdir, codec):
    """Round-trip the reference records through avro files for each codec."""
    if codec == 'snappy':
        # Snappy compression relies on an optional package; skip rather than
        # fail when it is missing, the same way this module treats fastavro.
        # NOTE(review): assumes the binding is importable as ``snappy`` --
        # confirm against the fastavro version in use.
        pytest.importorskip('snappy')
    fn = os.path.join(tmpdir, 'out*.avro')
    b = db.from_sequence(expected, npartitions=3)
    b.to_avro(fn, schema=schema, codec=codec)
    b2 = db.read_avro(fn)
    assert b.compute() == b2.compute()
def test_invalid_schema(tmpdir):
    """Check that ``to_avro`` rejects malformed schemas with AssertionError.

    Each entry below is a distinct malformed schema; every one of them must
    make ``to_avro`` raise before anything is written.
    """
    b = db.from_sequence(expected, npartitions=3)
    fn = os.path.join(tmpdir, 'out*.avro')
    invalid_schemas = [
        [],                                      # not a dict
        {},                                      # empty dict
        {'doc': 'unknown'},                      # no name/type/fields
        {'name': 'test'},                        # no type/fields
        {'name': 'test', 'type': 'wrong'},       # type is not 'record'
        {'name': 'test', 'type': 'record'},      # no fields
        {'name': 'test', 'type': 'record',
         'fields': [{'name': 'a'}]},             # field entry without a type
    ]
    for bad_schema in invalid_schemas:
        with pytest.raises(AssertionError):
            b.to_avro(fn, schema=bad_schema)