关于airflow:Airflow-从入门到精通03完整-ETL-实例

68次阅读

共计 13979 个字符,预计需要花费 35 分钟才能阅读完成。

本节将讲述如何使用 Connection、MySqlOperator、XComs 来实现一个完整的 Airflow ETL。

一、将数据存入数据库的原始方法

1、创建表

-- Create the demo schema and make it the default for the statements below.
CREATE database demodb;

use demodb;

-- Staging table: one row per (ticker, date) price record; it declares no
-- primary key and no timestamps.
create table stock_prices_stage(ticker varchar(30),
as_of_date date,
open_price double,
high_price double,
low_price double,
close_price double

)  COMMENT = '股票价格缓冲区表';

-- Target table: same price columns plus an auto-increment surrogate key and
-- created/updated timestamps that default to the insertion time.
create table stock_prices(
id int not null AUTO_INCREMENT,
ticker varchar(30),
as_of_date date  COMMENT '以后日期',
open_price double,
high_price double,
low_price double,
close_price double,
created_at timestamp default now(),
updated_at timestamp default now(),
primary key (id)
)COMMENT = '股票价格表';

-- Non-unique lookup indexes on (ticker, as_of_date) for both tables.
create index ids_stockprices on stock_prices(ticker, as_of_date);

create index ids_stockpricestage on stock_prices_stage(ticker, as_of_date);

二、使用 Airflow Connection 管理数据库连接信息

在上一节代码的基础上,将保存到文件的数据转存到数据库中,V2 版本的代码如下:

download_stock_price_v2.py

2.1 传统连接方法

"""Example DAG demonstrating the usage of the BashOperator."""

from datetime import timedelta
from textwrap import dedent
import yfinance as yf

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
import mysql.connector


def download_price(*args, **context):
    """Download one month of price history for every configured ticker.

    The ticker list comes from ``get_tickers(context)``. Each ticker's
    history is fetched through ``yfinance`` and written as CSV to the path
    returned by ``get_file_path(ticker)``.

    Args:
        *args: Unused positional arguments passed by the operator.
        **context: Airflow task context, forwarded to ``get_tickers``.
    """
    stock_list = get_tickers(context)
    for ticker in stock_list:
        hist = yf.Ticker(ticker).history(period="1mo")

        # newline='' lets the CSV writer control line endings itself, which
        # avoids blank lines between rows on platforms that translate "\n".
        with open(get_file_path(ticker), 'w', newline='') as writer:
            hist.to_csv(writer, index=True)
        print("Finished downloading price data for " + ticker)


def get_file_path(ticker):
    """Return the relative CSV path used to store *ticker*'s price history.

    The path is relative to the process's current working directory.
    """
    # NOTE: not safe in a distributed setup -- presumably because a task on
    # another machine would not see this local file.
    return './{}.csv'.format(ticker)


def load_price_data(ticker):
    """Read the CSV saved for *ticker* and return rows ready for insertion.

    Args:
        ticker: Symbol whose file (located via ``get_file_path``) is read.

    Returns:
        A list of ``[ticker, f0, f1, f2, f3, f4]`` lists, where ``f0..f4``
        are the first five CSV fields of each data row, as strings. Rows
        whose first field starts with ``Date`` (the header) and rows with
        fewer than five fields (including blank lines) are skipped.
    """
    import csv  # local import so the function does not rely on file-level imports

    rows = []
    with open(get_file_path(ticker), 'r', newline='') as reader:
        for fields in csv.reader(reader):
            # Blank or truncated lines cannot fill the five price columns.
            if len(fields) < 5:
                continue
            # Skip the header line.
            if fields[0].startswith('Date'):
                continue
            rows.append([ticker] + fields[:5])
    return rows


def get_tickers(context):
    """Return the tickers to process for this DAG run.

    The default list is the JSON-decoded Airflow Variable
    ``stock_list_json``. A non-empty ``stocks`` entry in the run's trigger
    configuration overrides it.

    Args:
        context: Airflow task context; must contain ``dag_run``.

    Returns:
        The ``stocks`` value from the run configuration when it is truthy,
        otherwise the Variable's value.
    """
    # Default list configured through Airflow Variables.
    stock_list = Variable.get("stock_list_json", deserialize_json=True)

    # Trigger-time parameters win (Trigger DAG with parameters). ``conf`` is
    # treated as empty when the run carries no configuration object.
    conf = context["dag_run"].conf or {}
    stocks = conf.get("stocks")
    return stocks if stocks else stock_list


def save_to_mysql_stage(*args, **context):
    """Insert the downloaded price rows into ``stock_prices_stage``.

    Args:
        *args: Unused positional arguments passed by the operator.
        **context: Airflow task context, forwarded to ``get_tickers``.
    """
    from contextlib import closing

    tickers = get_tickers(context)

    sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""

    # WARNING(security): the connection settings, including the password,
    # are hard-coded here. Prefer an Airflow Connection so that secrets stay
    # out of the DAG source.
    # closing() guarantees the connection and cursor are released even when
    # an insert fails part-way through.
    with closing(mysql.connector.connect(
        host="98.14.13.15",
        user="root",
        password="Quant888",
        database="demodb",
        port=3307
    )) as mydb, closing(mydb.cursor()) as mycursor:
        for ticker in tickers:
            val = load_price_data(ticker)
            if not val:
                # No rows for this ticker; the sample print below needs at
                # least one row.
                print(f"{ticker} length=0, nothing to insert")
                continue

            # Log a sample row (the second one when available, as before).
            sample = val[1] if len(val) > 1 else val[0]
            print(f"{ticker} length={len(val)} {sample}")

            mycursor.executemany(sql, val)
            mydb.commit()
            print(mycursor.rowcount, "record inserted.")


default_args = {'owner': 'airflow'}

# [START instantiate_dag]
# schedule_interval=None: the DAG only runs when triggered manually.
with DAG(
        dag_id='download_stock_price_v2',
        default_args=default_args,
        description='download stock price and save to local csv files and save to database',
        schedule_interval=None,
        start_date=days_ago(2),
        tags=['quantdata'],
) as dag:
    # [END instantiate_dag]

    dag.doc_md = """This DAG download stock price"""

    # Airflow 2 passes the task context to the callable automatically, so
    # the deprecated ``provide_context=True`` flag is not needed.
    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
    )

    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
    )

    # Load the stage table only after the CSV files have been written.
    download_task >> save_to_mysql_task

然后在 Airflow 后台手动触发执行,前两次执行失败,经过调试后执行成功了。

可以看到数据已经入库了:

2.2 使用 Airflow Connection 管理连接信息

上面的 demo 有个问题:数据库的连接信息直接硬编码在代码中,后期不便于维护。Airflow 为我们提供了 Connections 功能,可以把连接信息统一配置在其中。

选择连接类型时,发现缺少 MySQL 连接类型:

Conn Type missing? Make sure you've installed the corresponding Airflow Provider Package.

请看官网文档:
https://airflow.apache.org/do…
https://airflow.apache.org/do…
https://airflow.apache.org/do…

$ pip install apache-airflow-providers-mysql

然后重新刷新连接页面,可以看到连接类型 MySQL 已经出现了:

然后填入相关的数据库连接信息:

然后对代码进行修改:

def save_to_mysql_stage(*args, **context):
    """Insert the downloaded price rows into ``stock_prices_stage``.

    Connection settings are read from the Airflow Connection whose id is
    ``demodb`` instead of being hard-coded, so credentials stay out of the
    DAG source.

    Args:
        *args: Unused positional arguments passed by the operator.
        **context: Airflow task context, forwarded to ``get_tickers``.
    """
    from contextlib import closing

    # ``airflow.hooks.base_hook`` is the deprecated pre-2.0 module path.
    from airflow.hooks.base import BaseHook

    tickers = get_tickers(context)

    # Look up host, login, password, schema and port from Airflow Connections.
    conn = BaseHook.get_connection('demodb')

    sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""

    # closing() guarantees the connection and cursor are released even when
    # an insert fails part-way through.
    with closing(mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        # Fall back to MySQL's standard port when the Connection leaves it unset.
        port=conn.port or 3306
    )) as mydb, closing(mydb.cursor()) as mycursor:
        for ticker in tickers:
            val = load_price_data(ticker)
            if not val:
                # No rows for this ticker; the sample print below needs at
                # least one row.
                print(f"{ticker} length=0, nothing to insert")
                continue

            # Log a sample row (the second one when available, as before).
            sample = val[1] if len(val) > 1 else val[0]
            print(f"{ticker} length={len(val)} {sample}")

            mycursor.executemany(sql, val)
            mydb.commit()
            print(mycursor.rowcount, "record inserted.")

完整代码:

"""Example DAG demonstrating the usage of the BashOperator."""

from datetime import timedelta
from textwrap import dedent
import yfinance as yf

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from airflow.models import Variable
import mysql.connector


def download_price(*args, **context):
    """Download one month of price history for every configured ticker.

    The ticker list comes from ``get_tickers(context)``. Each ticker's
    history is fetched through ``yfinance`` and written as CSV to the path
    returned by ``get_file_path(ticker)``.

    Args:
        *args: Unused positional arguments passed by the operator.
        **context: Airflow task context, forwarded to ``get_tickers``.
    """
    stock_list = get_tickers(context)
    for ticker in stock_list:
        hist = yf.Ticker(ticker).history(period="1mo")

        # newline='' lets the CSV writer control line endings itself, which
        # avoids blank lines between rows on platforms that translate "\n".
        with open(get_file_path(ticker), 'w', newline='') as writer:
            hist.to_csv(writer, index=True)
        print("Finished downloading price data for " + ticker)


def get_file_path(ticker):
    """Return the relative CSV path used to store *ticker*'s price history.

    The path is relative to the process's current working directory.
    """
    # NOTE: not safe in a distributed setup -- presumably because a task on
    # another machine would not see this local file.
    return './{}.csv'.format(ticker)


def load_price_data(ticker):
    """Read the CSV saved for *ticker* and return rows ready for insertion.

    Args:
        ticker: Symbol whose file (located via ``get_file_path``) is read.

    Returns:
        A list of ``[ticker, f0, f1, f2, f3, f4]`` lists, where ``f0..f4``
        are the first five CSV fields of each data row, as strings. Rows
        whose first field starts with ``Date`` (the header) and rows with
        fewer than five fields (including blank lines) are skipped.
    """
    import csv  # local import so the function does not rely on file-level imports

    rows = []
    with open(get_file_path(ticker), 'r', newline='') as reader:
        for fields in csv.reader(reader):
            # Blank or truncated lines cannot fill the five price columns.
            if len(fields) < 5:
                continue
            # Skip the header line.
            if fields[0].startswith('Date'):
                continue
            rows.append([ticker] + fields[:5])
    return rows


def get_tickers(context):
    """Return the tickers to process for this DAG run.

    The default list is the JSON-decoded Airflow Variable
    ``stock_list_json``. A non-empty ``stocks`` entry in the run's trigger
    configuration overrides it.

    Args:
        context: Airflow task context; must contain ``dag_run``.

    Returns:
        The ``stocks`` value from the run configuration when it is truthy,
        otherwise the Variable's value.
    """
    # Default list configured through Airflow Variables.
    stock_list = Variable.get("stock_list_json", deserialize_json=True)

    # Trigger-time parameters win (Trigger DAG with parameters). ``conf`` is
    # treated as empty when the run carries no configuration object.
    conf = context["dag_run"].conf or {}
    stocks = conf.get("stocks")
    return stocks if stocks else stock_list


def save_to_mysql_stage(*args, **context):
    """Insert the downloaded price rows into ``stock_prices_stage``.

    Connection settings are read from the Airflow Connection whose id is
    ``demodb`` instead of being hard-coded, so credentials stay out of the
    DAG source.

    Args:
        *args: Unused positional arguments passed by the operator.
        **context: Airflow task context, forwarded to ``get_tickers``.
    """
    from contextlib import closing

    # ``airflow.hooks.base_hook`` is the deprecated pre-2.0 module path.
    from airflow.hooks.base import BaseHook

    tickers = get_tickers(context)

    # Look up host, login, password, schema and port from Airflow Connections.
    conn = BaseHook.get_connection('demodb')

    sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""

    # closing() guarantees the connection and cursor are released even when
    # an insert fails part-way through.
    with closing(mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        # Fall back to MySQL's standard port when the Connection leaves it unset.
        port=conn.port or 3306
    )) as mydb, closing(mydb.cursor()) as mycursor:
        for ticker in tickers:
            val = load_price_data(ticker)
            if not val:
                # No rows for this ticker; the sample print below needs at
                # least one row.
                print(f"{ticker} length=0, nothing to insert")
                continue

            # Log a sample row (the second one when available, as before).
            sample = val[1] if len(val) > 1 else val[0]
            print(f"{ticker} length={len(val)} {sample}")

            mycursor.executemany(sql, val)
            mydb.commit()
            print(mycursor.rowcount, "record inserted.")


default_args = {'owner': 'airflow'}

# [START instantiate_dag]
# schedule_interval=None: the DAG only runs when triggered manually.
with DAG(
        dag_id='download_stock_price_v2',
        default_args=default_args,
        description='download stock price and save to local csv files and save to database',
        schedule_interval=None,
        start_date=days_ago(2),
        tags=['quantdata'],
) as dag:
    # [END instantiate_dag]

    dag.doc_md = """This DAG download stock price"""

    # Airflow 2 passes the task context to the callable automatically, so
    # the deprecated ``provide_context=True`` flag is not needed.
    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
    )

    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
    )

    # Load the stage table only after the CSV files have been written.
    download_task >> save_to_mysql_task

三、使用 MySqlOperator 执行数据库操作

在 dags/ 目录下新建 sql 文件,用来将缓冲表(stage)的数据合并到正式表。

merge_stock_price.sql

-- Step 1: refresh the prices of rows that already exist in the target table.
UPDATE stock_prices p
JOIN stock_prices_stage s
  ON p.ticker = s.ticker
 AND p.as_of_date = s.as_of_date
SET p.open_price = s.open_price,
    p.high_price = s.high_price,
    p.low_price = s.low_price,
    p.close_price = s.close_price,
    p.updated_at = now();

-- Step 2: add staged rows that have no (ticker, as_of_date) match yet.
INSERT INTO stock_prices
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
SELECT s.ticker, s.as_of_date, s.open_price, s.high_price, s.low_price, s.close_price
FROM stock_prices_stage s
WHERE NOT EXISTS (
    SELECT 1
    FROM stock_prices p
    WHERE p.ticker = s.ticker
      AND p.as_of_date = s.as_of_date
);

-- Step 3: empty the stage table once its rows have been merged.
TRUNCATE TABLE stock_prices_stage;

在 download_stock_price_v2.py 文件中新建 MySQL task 任务,
需要先引入 MySqlOperator:

from airflow.providers.mysql.operators.mysql import MySqlOperator

# Runs merge_stock_price.sql against the 'demodb' connection. Because the
# DAG is passed explicitly via ``dag=dag``, this can be placed after the
# ``with DAG(...) as dag:`` block or, indented, inside it.
mysql_task = MySqlOperator(
    task_id="merge_stock_price",
    mysql_conn_id='demodb',
    sql="merge_stock_price.sql",
    dag=dag,
)

# Merge only after the stage table has been loaded.
download_task >> save_to_mysql_task >> mysql_task

完整代码:

"""Example DAG demonstrating the usage of the BashOperator."""

from datetime import timedelta
from textwrap import dedent
import yfinance as yf
import mysql.connector

from airflow import DAG
from airflow.operators.python import PythonOperator
# from airflow.operators.mysql_operator import MySqlOperator
from airflow.providers.mysql.operators.mysql import MySqlOperator

from airflow.utils.dates import days_ago
from airflow.models import Variable

def download_price(*args, **context):
    """Download one month of price history for every configured ticker.

    The ticker list comes from ``get_tickers(context)``. Each ticker's
    history is fetched through ``yfinance`` and written as CSV to the path
    returned by ``get_file_path(ticker)``.

    Args:
        *args: Unused positional arguments passed by the operator.
        **context: Airflow task context, forwarded to ``get_tickers``.
    """
    stock_list = get_tickers(context)
    for ticker in stock_list:
        hist = yf.Ticker(ticker).history(period="1mo")

        # newline='' lets the CSV writer control line endings itself, which
        # avoids blank lines between rows on platforms that translate "\n".
        with open(get_file_path(ticker), 'w', newline='') as writer:
            hist.to_csv(writer, index=True)
        print("Finished downloading price data for " + ticker)


def get_file_path(ticker):
    """Return the relative CSV path used to store *ticker*'s price history.

    The path is relative to the process's current working directory.
    """
    # NOTE: not safe in a distributed setup -- presumably because a task on
    # another machine would not see this local file.
    return './{}.csv'.format(ticker)


def load_price_data(ticker):
    """Read the CSV saved for *ticker* and return rows ready for insertion.

    Args:
        ticker: Symbol whose file (located via ``get_file_path``) is read.

    Returns:
        A list of ``[ticker, f0, f1, f2, f3, f4]`` lists, where ``f0..f4``
        are the first five CSV fields of each data row, as strings. Rows
        whose first field starts with ``Date`` (the header) and rows with
        fewer than five fields (including blank lines) are skipped.
    """
    import csv  # local import so the function does not rely on file-level imports

    rows = []
    with open(get_file_path(ticker), 'r', newline='') as reader:
        for fields in csv.reader(reader):
            # Blank or truncated lines cannot fill the five price columns.
            if len(fields) < 5:
                continue
            # Skip the header line.
            if fields[0].startswith('Date'):
                continue
            rows.append([ticker] + fields[:5])
    return rows


def get_tickers(context):
    """Return the tickers to process for this DAG run.

    The default list is the JSON-decoded Airflow Variable
    ``stock_list_json``. A non-empty ``stocks`` entry in the run's trigger
    configuration overrides it.

    Args:
        context: Airflow task context; must contain ``dag_run``.

    Returns:
        The ``stocks`` value from the run configuration when it is truthy,
        otherwise the Variable's value.
    """
    # Default list configured through Airflow Variables.
    stock_list = Variable.get("stock_list_json", deserialize_json=True)

    # Trigger-time parameters win (Trigger DAG with parameters). ``conf`` is
    # treated as empty when the run carries no configuration object.
    conf = context["dag_run"].conf or {}
    stocks = conf.get("stocks")
    return stocks if stocks else stock_list


def save_to_mysql_stage(*args, **context):
    """Insert the downloaded price rows into ``stock_prices_stage``.

    Connection settings are read from the Airflow Connection whose id is
    ``demodb`` instead of being hard-coded, so credentials stay out of the
    DAG source.

    Args:
        *args: Unused positional arguments passed by the operator.
        **context: Airflow task context, forwarded to ``get_tickers``.
    """
    from contextlib import closing

    # ``airflow.hooks.base_hook`` is the deprecated pre-2.0 module path.
    from airflow.hooks.base import BaseHook

    tickers = get_tickers(context)

    # Look up host, login, password, schema and port from Airflow Connections.
    conn = BaseHook.get_connection('demodb')

    sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""

    # closing() guarantees the connection and cursor are released even when
    # an insert fails part-way through.
    with closing(mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        # Fall back to MySQL's standard port when the Connection leaves it unset.
        port=conn.port or 3306
    )) as mydb, closing(mydb.cursor()) as mycursor:
        for ticker in tickers:
            val = load_price_data(ticker)
            if not val:
                # No rows for this ticker; the sample print below needs at
                # least one row.
                print(f"{ticker} length=0, nothing to insert")
                continue

            # Log a sample row (the second one when available, as before).
            sample = val[1] if len(val) > 1 else val[0]
            print(f"{ticker} length={len(val)} {sample}")

            mycursor.executemany(sql, val)
            mydb.commit()
            print(mycursor.rowcount, "record inserted.")


default_args = {'owner': 'airflow'}

# [START instantiate_dag]
# schedule_interval=None: the DAG only runs when triggered manually.
with DAG(
        dag_id='download_stock_price_v2',
        default_args=default_args,
        description='download stock price and save to local csv files and save to database',
        schedule_interval=None,
        start_date=days_ago(2),
        tags=['quantdata'],
) as dag:
    # [END instantiate_dag]

    dag.doc_md = """This DAG download stock price"""

    # Airflow 2 passes the task context to the callable automatically, so
    # the deprecated ``provide_context=True`` flag is not needed.
    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
    )

    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
    )

    # Runs merge_stock_price.sql against the 'demodb' connection.
    mysql_task = MySqlOperator(
        task_id="merge_stock_price",
        mysql_conn_id='demodb',
        sql="merge_stock_price.sql",
        dag=dag,
    )

    # Download -> load the stage table -> merge the stage into the target table.
    download_task >> save_to_mysql_task >> mysql_task

然后手动执行 DAG,可以看到已经执行成功了:

然后查看相关表数据,也已经更新成功了。

四、使用 XComs 在任务之间传递数据

XComs 概念

XComs(“cross-communications”,即“交叉通信”的缩写)是一种让任务互相通信的机制,因为默认情况下任务是完全隔离的,并且可能运行在完全不同的机器上。

XCom 由一个键(本质上是它的名称)以及它来自的 task_id 和 dag_id 标识。它们可以具有任何(可序列化的)值,但它们仅适用于少量数据;不要使用它们来传递大值,例如数据帧(DataFrame)。

简单一句话,XComs 可以在多个 task 之间进行通信(数据的传递)。

XComs are explicitly “pushed” and “pulled” to/from their storage using the xcom_push and xcom_pull methods on Task Instances. Many operators will auto-push their results into an XCom key called return_value if the do_xcom_push argument is set to True (as it is by default), and @task functions do this as well.

# Pull the XCom stored under the default ``return_value`` key by the task
# whose task_id is "pushing_task".
value = task_instance.xcom_pull(task_ids='pushing_task')

实战应用

使用场景:增加一只不存在的股票,然后对股票进行验证,只有存在的股票才会传递给后续任务。

修改 download_stock_price_v2.py 文件中的下载代码:

然后在将股票保存到 MySQL stage 时,通过上一步的返回值来获取已经过滤的 ticker。

download_stock_price_v2.py 完整代码

"""Example DAG demonstrating the usage of the BashOperator."""

from datetime import timedelta
from textwrap import dedent
import yfinance as yf
import mysql.connector

from airflow import DAG
from airflow.operators.python import PythonOperator
# from airflow.operators.mysql_operator import MySqlOperator
from airflow.providers.mysql.operators.mysql import MySqlOperator

from airflow.utils.dates import days_ago
from airflow.models import Variable



def download_price(*args, **context):
    """Download price history and report which tickers returned data.

    The ticker list comes from ``get_tickers(context)``. Tickers whose
    one-month history is empty are skipped; every other ticker's history is
    written as CSV to the path returned by ``get_file_path(ticker)``.

    Args:
        *args: Unused positional arguments passed by the operator.
        **context: Airflow task context, forwarded to ``get_tickers``.

    Returns:
        The list of tickers for which a CSV file was written.
    """
    stock_list = get_tickers(context)

    # Tickers that actually returned price rows (i.e. not delisted or
    # non-existent symbols).
    valid_tickers = []
    for ticker in stock_list:
        hist = yf.Ticker(ticker).history(period="1mo")

        # An empty frame means there is no data for this symbol: skip it.
        if hist.shape[0] == 0:
            continue
        valid_tickers.append(ticker)

        # newline='' lets the CSV writer control line endings itself, which
        # avoids blank lines between rows on platforms that translate "\n".
        with open(get_file_path(ticker), 'w', newline='') as writer:
            hist.to_csv(writer, index=True)
        print("Finished downloading price data for " + ticker)

    # The return value is what the downstream task receives via xcom_pull.
    return valid_tickers


def get_file_path(ticker):
    """Return the relative CSV path used to store *ticker*'s price history.

    The path is relative to the process's current working directory.
    """
    # NOTE: not safe in a distributed setup -- presumably because a task on
    # another machine would not see this local file.
    return './{}.csv'.format(ticker)


def load_price_data(ticker):
    """Read the CSV saved for *ticker* and return rows ready for insertion.

    Args:
        ticker: Symbol whose file (located via ``get_file_path``) is read.

    Returns:
        A list of ``[ticker, f0, f1, f2, f3, f4]`` lists, where ``f0..f4``
        are the first five CSV fields of each data row, as strings. Rows
        whose first field starts with ``Date`` (the header) and rows with
        fewer than five fields (including blank lines) are skipped.
    """
    import csv  # local import so the function does not rely on file-level imports

    rows = []
    with open(get_file_path(ticker), 'r', newline='') as reader:
        for fields in csv.reader(reader):
            # Blank or truncated lines cannot fill the five price columns.
            if len(fields) < 5:
                continue
            # Skip the header line.
            if fields[0].startswith('Date'):
                continue
            rows.append([ticker] + fields[:5])
    return rows


def get_tickers(context):
    """Return the tickers to process for this DAG run.

    The default list is the JSON-decoded Airflow Variable
    ``stock_list_json``. A non-empty ``stocks`` entry in the run's trigger
    configuration overrides it.

    Args:
        context: Airflow task context; must contain ``dag_run``.

    Returns:
        The ``stocks`` value from the run configuration when it is truthy,
        otherwise the Variable's value.
    """
    # Default list configured through Airflow Variables.
    stock_list = Variable.get("stock_list_json", deserialize_json=True)

    # Trigger-time parameters win (Trigger DAG with parameters). ``conf`` is
    # treated as empty when the run carries no configuration object.
    conf = context["dag_run"].conf or {}
    stocks = conf.get("stocks")
    return stocks if stocks else stock_list


def save_to_mysql_stage(*args, **context):
    """Insert price rows for the tickers validated upstream into the stage table.

    The ticker list is the XCom return value of the ``download_prices``
    task. Connection settings are read from the Airflow Connection whose id
    is ``demodb``.

    Args:
        *args: Unused positional arguments passed by the operator.
        **context: Airflow task context; ``context['ti']`` is used to pull
            the upstream XCom.
    """
    from contextlib import closing

    # ``airflow.hooks.base_hook`` is the deprecated pre-2.0 module path.
    from airflow.hooks.base import BaseHook

    # Pull the return_value XCom pushed by the "download_prices" task.
    # xcom_pull yields None when nothing was pushed; treat that as "no tickers".
    tickers = context['ti'].xcom_pull(task_ids='download_prices') or []
    print(f"received tickers:{tickers}")

    # Look up host, login, password, schema and port from Airflow Connections.
    conn = BaseHook.get_connection('demodb')

    sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""

    # closing() guarantees the connection and cursor are released even when
    # an insert fails part-way through.
    with closing(mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        # Fall back to MySQL's standard port when the Connection leaves it unset.
        port=conn.port or 3306
    )) as mydb, closing(mydb.cursor()) as mycursor:
        for ticker in tickers:
            val = load_price_data(ticker)
            if not val:
                # No rows for this ticker; the sample print below needs at
                # least one row.
                print(f"{ticker} length=0, nothing to insert")
                continue

            # Log a sample row (the second one when available, as before).
            sample = val[1] if len(val) > 1 else val[0]
            print(f"{ticker} length={len(val)} {sample}")

            mycursor.executemany(sql, val)
            mydb.commit()
            print(mycursor.rowcount, "record inserted.")


default_args = {'owner': 'airflow'}

# [START instantiate_dag]
# schedule_interval=None: the DAG only runs when triggered manually.
with DAG(
        dag_id='download_stock_price_v2',
        default_args=default_args,
        description='download stock price and save to local csv files and save to database',
        schedule_interval=None,
        start_date=days_ago(2),
        tags=['quantdata'],
) as dag:
    # [END instantiate_dag]

    dag.doc_md = """This DAG download stock price"""

    # Airflow 2 passes the task context to the callable automatically, so
    # the deprecated ``provide_context=True`` flag is not needed.
    download_task = PythonOperator(
        task_id="download_prices",
        python_callable=download_price,
    )

    # Reads the upstream task's return value through XCom.
    save_to_mysql_task = PythonOperator(
        task_id="save_to_database",
        python_callable=save_to_mysql_stage,
    )

    # Runs merge_stock_price.sql against the 'demodb' connection.
    mysql_task = MySqlOperator(
        task_id="merge_stock_price",
        mysql_conn_id='demodb',
        sql="merge_stock_price.sql",
        dag=dag,
    )

    # Download -> load the stage table -> merge the stage into the target table.
    download_task >> save_to_mysql_task >> mysql_task

然后在 Variables 中增加一个不存在的 ticker(FBXXOO),以此来验证 XCom 的数据传递:

手动执行 DAG,可以通过日志打印看到,tickers = context['ti'].xcom_pull(task_ids='download_prices') 已经通过 XCom 获取到了上一个任务传递过来的数据。


相干文章:
Airflow 相干概念文档
Airflow XComs 官网文档

正文完
 0