This section walks through using Connections, the MySqlOperator, and XComs to build a complete Airflow ETL pipeline.

1. The original way of saving data into the database

1.1 Creating the tables

CREATE DATABASE demodb;

USE demodb;

CREATE TABLE stock_prices_stage (
    ticker      VARCHAR(30),
    as_of_date  DATE,
    open_price  DOUBLE,
    high_price  DOUBLE,
    low_price   DOUBLE,
    close_price DOUBLE
) COMMENT = 'stock price staging table';

CREATE TABLE stock_prices (
    id          INT NOT NULL AUTO_INCREMENT,
    ticker      VARCHAR(30),
    as_of_date  DATE COMMENT 'as-of date',
    open_price  DOUBLE,
    high_price  DOUBLE,
    low_price   DOUBLE,
    close_price DOUBLE,
    created_at  TIMESTAMP DEFAULT now(),
    updated_at  TIMESTAMP DEFAULT now(),
    PRIMARY KEY (id)
) COMMENT = 'stock price table';

CREATE INDEX ids_stockprices ON stock_prices (ticker, as_of_date);
CREATE INDEX ids_stockpricestage ON stock_prices_stage (ticker, as_of_date);

2. Managing database connection info with Airflow Connections

Building on the previous section's code, we now copy the data that was saved to files into the database. The V2 code is below:

download_stock_price_v2.py

2.1 The traditional connection approach

"""Example DAG demonstrating the usage of the BashOperator."""from datetime import timedeltafrom textwrap import dedentimport yfinance as yffrom airflow import DAGfrom airflow.operators.python import PythonOperatorfrom airflow.utils.dates import days_agofrom airflow.models import Variableimport mysql.connectordef download_price(*args, **context):    stock_list = get_tickers(context)    for ticker in stock_list:        dat = yf.Ticker(ticker)        hist = dat.history(period="1mo")        # print(type(hist))        # print(hist.shape)        # print(os.getcwd())        with open(get_file_path(ticker), 'w') as writer:            hist.to_csv(writer, index=True)        print("Finished downloading price data for " + ticker)def get_file_path(ticker):    # NOT SAVE in distributed system    return f'./{ticker}.csv'def load_price_data(ticker):    with open(get_file_path(ticker), 'r') as reader:        lines = reader.readlines()        return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']def get_tickers(context):    # 获取配置的变量Variables    stock_list = Variable.get("stock_list_json", deserialize_json=True)    # 如果有配置参数,则应用配置参数的数据(Trigger DAG with parameters)    stocks = context["dag_run"].conf.get("stocks")    if stocks:        stock_list = stocks    return stock_listdef save_to_mysql_stage(*args, **context):    tickers = get_tickers(context)    # 连贯数据库    mydb = mysql.connector.connect(        host="98.14.13.15",        user="root",        password="Quant888",        database="demodb",        port=3307    )    mycursor = mydb.cursor()    for ticker in tickers:        val = load_price_data(ticker)        print(f"{ticker} length={len(val)} {val[1]}")        sql = """INSERT INTO stock_prices_stage        (ticker, as_of_date, open_price, high_price, low_price, close_price)        VALUES (%s,%s,%s,%s,%s,%s)"""        mycursor.executemany(sql, val)        mydb.commit()        print(mycursor.rowcount, "record inserted.")default_args = {    'owner': 'airflow'}# [START instantiate_dag]with DAG(        dag_id='download_stock_price_v2',        default_args=default_args,        description='download stock price and save to local csv files and save to database',        schedule_interval=None,        start_date=days_ago(2),        tags=['quantdata'],) as dag:    # [END instantiate_dag]    dag.doc_md = """    This DAG download stock price    """    download_task = PythonOperator(        task_id="download_prices",        python_callable=download_price,        provide_context=True    )    save_to_mysql_task = PythonOperator(        task_id="save_to_database",        python_callable=save_to_mysql_stage,        provide_context=True    )    download_task >> save_to_mysql_task

Then trigger the DAG manually from the Airflow UI. The first two runs failed; after some debugging it ran successfully.

You can see that the data has been loaded into the database:

2.2 Managing connection info with Airflow Connections

The demo above has a problem: the database connection details are hard-coded in the DAG, which makes later maintenance painful. Airflow provides Connections for exactly this; we can store the connection details there instead.

When selecting the connection type, the MySQL connection type is missing:

Conn Type missing? Make sure you've installed the corresponding Airflow Provider Package.

See the official docs:
https://airflow.apache.org/do...
https://airflow.apache.org/do...
https://airflow.apache.org/do...

$ pip install apache-airflow-providers-mysql

Refresh the connections page; the MySQL connection type now shows up:

Then fill in the database connection details:
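The same connection can also be created outside the UI. A minimal sketch, assuming Airflow 2.x and reusing the values from the hard-coded example above (adjust host and credentials to your own setup):

from airflow import settings
from airflow.models import Connection

# Build a Connection object equivalent to what the UI form stores
conn = Connection(
    conn_id="demodb",
    conn_type="mysql",
    host="98.14.13.15",   # values from the hard-coded demo above
    login="root",
    password="Quant888",
    schema="demodb",
    port=3307,
)

# Persist it to the Airflow metadata database
session = settings.Session()
session.add(conn)
session.commit()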

Then update the code:

def save_to_mysql_stage(*args, **context):
    tickers = get_tickers(context)
    """
    # Connect to the database (hard-coded credentials)
    mydb = mysql.connector.connect(
        host="98.14.14.145",
        user="root",
        password="Quant888",
        database="demodb",
        port=3307
    )
    """
    # Fetch the connection details dynamically from Airflow Connections
    from airflow.hooks.base_hook import BaseHook
    conn = BaseHook.get_connection('demodb')
    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port
    )
    mycursor = mydb.cursor()
    for ticker in tickers:
        val = load_price_data(ticker)
        print(f"{ticker} length={len(val)} {val[1]}")
        sql = """INSERT INTO stock_prices_stage
        (ticker, as_of_date, open_price, high_price, low_price, close_price)
        VALUES (%s,%s,%s,%s,%s,%s)"""
        mycursor.executemany(sql, val)
        mydb.commit()
        print(mycursor.rowcount, "record inserted.")

Full code:

"""Example DAG demonstrating the usage of the BashOperator."""from datetime import timedeltafrom textwrap import dedentimport yfinance as yffrom airflow import DAGfrom airflow.operators.python import PythonOperatorfrom airflow.utils.dates import days_agofrom airflow.models import Variableimport mysql.connectordef download_price(*args, **context):    stock_list = get_tickers(context)    for ticker in stock_list:        dat = yf.Ticker(ticker)        hist = dat.history(period="1mo")        # print(type(hist))        # print(hist.shape)        # print(os.getcwd())        with open(get_file_path(ticker), 'w') as writer:            hist.to_csv(writer, index=True)        print("Finished downloading price data for " + ticker)def get_file_path(ticker):    # NOT SAVE in distributed system    return f'./{ticker}.csv'def load_price_data(ticker):    with open(get_file_path(ticker), 'r') as reader:        lines = reader.readlines()        return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']def get_tickers(context):    # 获取配置的变量Variables    stock_list = Variable.get("stock_list_json", deserialize_json=True)    # 如果有配置参数,则应用配置参数的数据(Trigger DAG with parameters)    stocks = context["dag_run"].conf.get("stocks")    if stocks:        stock_list = stocks    return stock_listdef save_to_mysql_stage(*args, **context):    tickers = get_tickers(context)    """    # 连贯数据库(硬编码方式连贯)    mydb = mysql.connector.connect(        host="98.14.13.14",        user="root",        password="Quan888",        database="demodb",        port=3307    )    """    # 应用airflow 的 Connections 动静获取配置信息    from airflow.hooks.base_hook import BaseHook    conn = BaseHook.get_connection('demodb')    mydb = mysql.connector.connect(        host=conn.host,        user=conn.login,        password=conn.password,        database=conn.schema,        port=conn.port    )    mycursor = mydb.cursor()    for ticker in tickers:        val = load_price_data(ticker)        print(f"{ticker} length={len(val)} {val[1]}")        sql = """INSERT INTO stock_prices_stage        (ticker, as_of_date, open_price, high_price, low_price, close_price)        VALUES (%s,%s,%s,%s,%s,%s)"""        mycursor.executemany(sql, val)        mydb.commit()        print(mycursor.rowcount, "record inserted.")default_args = {    'owner': 'airflow'}# [START instantiate_dag]with DAG(        dag_id='download_stock_price_v2',        default_args=default_args,        description='download stock price and save to local csv files and save to database',        schedule_interval=None,        start_date=days_ago(2),        tags=['quantdata'],) as dag:    # [END instantiate_dag]    dag.doc_md = """    This DAG download stock price    """    download_task = PythonOperator(        task_id="download_prices",        python_callable=download_price,        provide_context=True    )    save_to_mysql_task = PythonOperator(        task_id="save_to_database",        python_callable=save_to_mysql_stage,        provide_context=True    )    download_task >> save_to_mysql_task

3. Running database operations with the MySqlOperator

Create a new SQL file under the dags/ directory to merge the staging table's data into the main table: update the rows that already exist, insert the ones that don't, then clear the staging table.

merge_stock_price.sql

-- update the existing rows
UPDATE stock_prices p, stock_prices_stage s
SET p.open_price = s.open_price,
    p.high_price = s.high_price,
    p.low_price = s.low_price,
    p.close_price = s.close_price,
    p.updated_at = now()
WHERE p.ticker = s.ticker
  AND p.as_of_date = s.as_of_date;

-- insert new rows
INSERT INTO stock_prices
(ticker, as_of_date, open_price, high_price, low_price, close_price)
SELECT ticker, as_of_date, open_price, high_price, low_price, close_price
FROM stock_prices_stage s
WHERE NOT EXISTS
    (SELECT 1 FROM stock_prices p
      WHERE p.ticker = s.ticker
        AND p.as_of_date = s.as_of_date);

-- truncate the stage table
TRUNCATE TABLE stock_prices_stage;
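If you want to try the merge by hand before wiring it into the DAG, one option (a sketch, assuming the demodb Connection from section 2) is to run the file's statements through MySqlHook:

from airflow.providers.mysql.hooks.mysql import MySqlHook

# Split the file into individual statements, since the script
# contains several of them separated by semicolons
with open("dags/merge_stock_price.sql") as f:
    statements = [s.strip() for s in f.read().split(";") if s.strip()]

hook = MySqlHook(mysql_conn_id="demodb")
hook.run(statements, autocommit=True)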

Add a new MySQL task to download_stock_price_v2.py. First import the operator:

from airflow.providers.mysql.operators.mysql import MySqlOperator

    mysql_task = MySqlOperator(
        task_id="merge_stock_price",
        mysql_conn_id='demodb',
        sql="merge_stock_price.sql",
        dag=dag,
    )

    download_task >> save_to_mysql_task >> mysql_task

Because .sql is one of the operator's templated extensions, the sql argument is resolved as a file path relative to the DAG file (or any template_searchpath you configure).

Full code:

"""Example DAG demonstrating the usage of the BashOperator."""from datetime import timedeltafrom textwrap import dedentimport yfinance as yfimport mysql.connectorfrom airflow import DAGfrom airflow.operators.python import PythonOperator# from airflow.operators.mysql_operator import MySqlOperatorfrom airflow.providers.mysql.operators.mysql import MySqlOperatorfrom airflow.utils.dates import days_agofrom airflow.models import Variabledef download_price(*args, **context):    stock_list = get_tickers(context)    for ticker in stock_list:        dat = yf.Ticker(ticker)        hist = dat.history(period="1mo")        # print(type(hist))        # print(hist.shape)        # print(os.getcwd())        with open(get_file_path(ticker), 'w') as writer:            hist.to_csv(writer, index=True)        print("Finished downloading price data for " + ticker)def get_file_path(ticker):    # NOT SAVE in distributed system    return f'./{ticker}.csv'def load_price_data(ticker):    with open(get_file_path(ticker), 'r') as reader:        lines = reader.readlines()        return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']def get_tickers(context):    # 获取配置的变量Variables    stock_list = Variable.get("stock_list_json", deserialize_json=True)    # 如果有配置参数,则应用配置参数的数据(Trigger DAG with parameters)    stocks = context["dag_run"].conf.get("stocks")    if stocks:        stock_list = stocks    return stock_listdef save_to_mysql_stage(*args, **context):    tickers = get_tickers(context)    """    # 连贯数据库(硬编码方式连贯)    mydb = mysql.connector.connect(        host="98.14.14.15",        user="root",        password="Quan888",        database="demodb",        port=3307    )    """    # 应用airflow 的 Connections 动静获取配置信息    from airflow.hooks.base_hook import BaseHook    conn = BaseHook.get_connection('demodb')    mydb = mysql.connector.connect(        host=conn.host,        user=conn.login,        password=conn.password,        database=conn.schema,        port=conn.port    )    mycursor = mydb.cursor()    for ticker in tickers:        val = load_price_data(ticker)        print(f"{ticker} length={len(val)} {val[1]}")        sql = """INSERT INTO stock_prices_stage        (ticker, as_of_date, open_price, high_price, low_price, close_price)        VALUES (%s,%s,%s,%s,%s,%s)"""        mycursor.executemany(sql, val)        mydb.commit()        print(mycursor.rowcount, "record inserted.")default_args = {    'owner': 'airflow'}# [START instantiate_dag]with DAG(        dag_id='download_stock_price_v2',        default_args=default_args,        description='download stock price and save to local csv files and save to database',        schedule_interval=None,        start_date=days_ago(2),        tags=['quantdata'],) as dag:    # [END instantiate_dag]    dag.doc_md = """    This DAG download stock price    """    download_task = PythonOperator(        task_id="download_prices",        python_callable=download_price,        provide_context=True    )    save_to_mysql_task = PythonOperator(        task_id="save_to_database",        python_callable=save_to_mysql_stage,        provide_context=True    )    mysql_task = MySqlOperator(        task_id="merge_stock_price",        mysql_conn_id='demodb',        sql="merge_stock_price.sql",        dag=dag,    )    download_task >> save_to_mysql_task >> mysql_task

Then trigger the DAG manually; you can see it runs successfully:

Checking the tables confirms that the data has been merged as well.

4. Passing data between tasks with XComs

What XComs are

XComs (short for "cross-communications") are a mechanism that lets tasks talk to each other, since by default tasks are fully isolated and may run on entirely different machines.

An XCom is identified by a key (essentially its name), plus the task_id and dag_id it came from. XComs can hold any (serializable) value, but they are only meant for small amounts of data; do not use them to pass around large values such as dataframes.

In short, XComs let multiple tasks communicate and pass data to one another.

XComs are explicitly "pushed" and "pulled" to/from their storage using the xcom_push and xcom_pull methods on Task Instances. Many operators will auto-push their results into an XCom key called return_value if the do_xcom_push argument is set to True (as it is by default), and @task functions do this as well.

# Pulls the return_value XCOM from "pushing_task"
value = task_instance.xcom_pull(task_ids='pushing_task')
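For completeness, a minimal sketch of an explicit push/pull pair under a custom key (the task and key names here are illustrative, not part of this article's DAG):

def push_fn(**context):
    # Explicitly push a value under a custom key
    context["ti"].xcom_push(key="my_key", value={"hello": "world"})

def pull_fn(**context):
    # Pull it back by task id and key from another task
    value = context["ti"].xcom_pull(task_ids="push_task", key="my_key")
    print(value)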

Hands-on application

Scenario: add a ticker that doesn't exist, have the download task validate each ticker, and pass only the valid ones downstream.

Modify the download code in download_stock_price_v2.py:
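The relevant change in download_price (extracted from the full file below): skip tickers that return no rows, and return the surviving list so the PythonOperator pushes it to XCom as return_value.

def download_price(*args, **context):
    stock_list = get_tickers(context)
    # Collect only tickers that actually returned price data
    valid_tickers = []
    for ticker in stock_list:
        dat = yf.Ticker(ticker)
        hist = dat.history(period="1mo")
        if hist.shape[0] > 0:
            valid_tickers.append(ticker)
        else:
            continue
        with open(get_file_path(ticker), 'w') as writer:
            hist.to_csv(writer, index=True)
        print("Finished downloading price data for " + ticker)
    # The returned value is pushed to XCom under the key "return_value"
    return valid_tickers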

Then, when saving prices to the MySQL staging table, obtain the already-filtered tickers from the previous step's return value instead of re-reading the Variable.

Full code for download_stock_price_v2.py:

"""Example DAG demonstrating the usage of the BashOperator."""from datetime import timedeltafrom textwrap import dedentimport yfinance as yfimport mysql.connectorfrom airflow import DAGfrom airflow.operators.python import PythonOperator# from airflow.operators.mysql_operator import MySqlOperatorfrom airflow.providers.mysql.operators.mysql import MySqlOperatorfrom airflow.utils.dates import days_agofrom airflow.models import Variabledef download_price(*args, **context):    stock_list = get_tickers(context)    # 新增失常的股票(没有退市的或不存在的)    valid_tickers = []    for ticker in stock_list:        dat = yf.Ticker(ticker)        hist = dat.history(period="1mo")        # print(type(hist))        # print(hist.shape)        # print(os.getcwd())        if hist.shape[0] > 0:            valid_tickers.append(ticker)        else:            continue        with open(get_file_path(ticker), 'w') as writer:            hist.to_csv(writer, index=True)        print("Finished downloading price data for " + ticker)    # 减少返回值(用于工作之间数据的传递)    return valid_tickersdef get_file_path(ticker):    # NOT SAVE in distributed system    return f'./{ticker}.csv'def load_price_data(ticker):    with open(get_file_path(ticker), 'r') as reader:        lines = reader.readlines()        return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']def get_tickers(context):    # 获取配置的变量Variables    stock_list = Variable.get("stock_list_json", deserialize_json=True)    # 如果有配置参数,则应用配置参数的数据(Trigger DAG with parameters)    stocks = context["dag_run"].conf.get("stocks")    if stocks:        stock_list = stocks    return stock_listdef save_to_mysql_stage(*args, **context):    # tickers = get_tickers(context)    # Pull the return_value XCOM from "pulling_task"    tickers = context['ti'].xcom_pull(task_ids='download_prices')    print(f"received tickers:{tickers}")    """    # 连贯数据库(硬编码方式连贯)    mydb = mysql.connector.connect(        host="98.14.14.15",        user="root",        password="Quant888",        database="demodb",        port=3307    )    """    # 应用airflow 的 Connections 动静获取配置信息    from airflow.hooks.base_hook import BaseHook    conn = BaseHook.get_connection('demodb')    mydb = mysql.connector.connect(        host=conn.host,        user=conn.login,        password=conn.password,        database=conn.schema,        port=conn.port    )    mycursor = mydb.cursor()    for ticker in tickers:        val = load_price_data(ticker)        print(f"{ticker} length={len(val)} {val[1]}")        sql = """INSERT INTO stock_prices_stage        (ticker, as_of_date, open_price, high_price, low_price, close_price)        VALUES (%s,%s,%s,%s,%s,%s)"""        mycursor.executemany(sql, val)        mydb.commit()        print(mycursor.rowcount, "record inserted.")default_args = {    'owner': 'airflow'}# [START instantiate_dag]with DAG(        dag_id='download_stock_price_v2',        default_args=default_args,        description='download stock price and save to local csv files and save to database',        schedule_interval=None,        start_date=days_ago(2),        tags=['quantdata'],) as dag:    # [END instantiate_dag]    dag.doc_md = """    This DAG download stock price    """    download_task = PythonOperator(        task_id="download_prices",        python_callable=download_price,        provide_context=True    )    save_to_mysql_task = PythonOperator(        task_id="save_to_database",        python_callable=save_to_mysql_stage,        provide_context=True    )    mysql_task = MySqlOperator(        task_id="merge_stock_price",     
   mysql_conn_id='demodb',        sql="merge_stock_price.sql",        dag=dag,    )    download_task >> save_to_mysql_task >> mysql_task

Then add a non-existent ticker (FBXXOO) to the stock_list_json Variable to verify the XCom data passing:
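The Variable's value is a JSON list of tickers; for illustration it might look like this (every ticker except FBXXOO is a placeholder for whatever you configured earlier):

["IBM", "AAPL", "MSFT", "FBXXOO"]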

Trigger the DAG manually; in the logs you can see that tickers = context['ti'].xcom_pull(task_ids='download_prices') received the data passed from the previous task.


Related articles:
Airflow concepts documentation
Airflow XComs official documentation